Skip to content

Commit a1be911

Browse files
committed
MNT rename functions and add tests (#31)
1 parent 46ddb76 commit a1be911

File tree

4 files changed

+65
-58
lines changed

4 files changed

+65
-58
lines changed

docs/api.rst

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -38,20 +38,6 @@ This is the full API documentation of the `mlresearch` package.
3838
data_augmentation.GeometricSMOTE
3939
data_augmentation.OverSamplingAugmentation
4040

41-
:mod:`mlresearch.preprocessing`
42-
-------------------------------
43-
44-
.. automodule:: mlresearch.preprocessing
45-
:no-members:
46-
:no-inherited-members:
47-
48-
.. currentmodule:: mlresearch
49-
50-
.. autosummary::
51-
:toctree: _generated/
52-
:template: class.rst
53-
54-
preprocessing.PipelineEncoder
5541

5642
:mod:`mlresearch.datasets`
5743
--------------------------
@@ -72,6 +58,25 @@ This is the full API documentation of the `mlresearch` package.
7258
datasets.MultiClassDatasets
7359
datasets.RemoteSensingDatasets
7460

61+
:mod:`mlresearch.latex`
62+
-----------------------
63+
64+
.. automodule:: mlresearch.latex
65+
:no-members:
66+
:no-inherited-members:
67+
68+
.. currentmodule:: mlresearch
69+
70+
.. autosummary::
71+
:toctree: _generated/
72+
:template: function.rst
73+
74+
latex.format_table
75+
latex.make_bold
76+
latex.make_mean_sem_table
77+
latex.export_longtable
78+
79+
7580
:mod:`mlresearch.metrics`
7681
-------------------------
7782

@@ -97,23 +102,20 @@ This is the full API documentation of the `mlresearch` package.
97102

98103
metrics.ALScorer
99104

100-
:mod:`mlresearch.latex`
101-
-----------------------
105+
:mod:`mlresearch.preprocessing`
106+
-------------------------------
102107

103-
.. automodule:: mlresearch.latex
104-
:no-members:
105-
:no-inherited-members:
108+
.. automodule:: mlresearch.preprocessing
109+
:no-members:
110+
:no-inherited-members:
106111

107112
.. currentmodule:: mlresearch
108113

109114
.. autosummary::
110-
:toctree: _generated/
111-
:template: function.rst
115+
:toctree: _generated/
116+
:template: class.rst
112117

113-
latex.format_table
114-
latex.make_bold
115-
latex.make_mean_sem_table
116-
latex.export_longtable
118+
preprocessing.PipelineEncoder
117119

118120
:mod:`mlresearch.utils`
119121
-----------------------
@@ -133,6 +135,6 @@ This is the full API documentation of the `mlresearch` package.
133135
utils.load_datasets
134136
utils.check_pipelines
135137
utils.check_pipelines_wrapper
136-
utils.load_plt_sns_configs
137-
utils.val_to_color
138+
utils.set_matplotlib_style
139+
utils.feature_to_color
138140
utils.generate_paths

mlresearch/utils/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from ._image import image_to_dataframe, dataframe_to_image
66
from ._data import load_datasets
77
from ._check_pipelines import check_pipelines, check_pipelines_wrapper
8-
from ._visualization import load_plt_sns_configs, val_to_color
8+
from ._visualization import set_matplotlib_style, feature_to_color
99
from ._utils import generate_paths
1010

1111
__all__ = [
@@ -14,7 +14,7 @@
1414
"load_datasets",
1515
"check_pipelines",
1616
"check_pipelines_wrapper",
17-
"load_plt_sns_configs",
18-
"val_to_color",
17+
"set_matplotlib_style",
18+
"feature_to_color",
1919
"generate_paths",
2020
]

mlresearch/utils/_visualization.py

Lines changed: 12 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -38,38 +38,20 @@ def _optional_import(module: str) -> types.ModuleType:
3838
return module_
3939

4040

41-
def load_plt_sns_configs(font_size=8, **rcparams):
41+
def set_matplotlib_style(font_size=8, **rcparams):
4242
"""
43-
Load LaTeX style configurations for Matplotlib Visualizations.
43+
Load LaTeX-style configurations for Matplotlib Visualizations.
4444
"""
4545
plt = _optional_import("matplotlib.pyplot")
4646

4747
# Replicates the rcParams of seaborn's "whitegrid" style and a few extra
4848
# configurations I like
49+
plt.style.use("seaborn-v0_8-whitegrid")
4950
base_style = {
50-
# Whitegrid
51-
"axes.axisbelow": True,
52-
"axes.edgecolor": ".8",
53-
"axes.grid": True,
54-
"axes.labelcolor": ".15",
55-
"font.sans-serif": [
56-
"Arial",
57-
"DejaVu Sans",
58-
"Liberation Sans",
59-
"Bitstream Vera Sans",
60-
"sans-serif",
61-
],
62-
"grid.color": ".8",
63-
"image.cmap": "rocket",
64-
"lines.solid_capstyle": "round",
65-
"patch.edgecolor": "w",
66-
"patch.force_edgecolor": True,
67-
"text.color": ".15",
68-
"xtick.bottom": False,
69-
"xtick.color": ".15",
70-
"ytick.color": ".15",
71-
"ytick.left": False,
72-
# Extras
51+
# "patch.edgecolor": "w",
52+
# "patch.force_edgecolor": True,
53+
# "xtick.bottom": False,
54+
# "ytick.left": False,
7355
"font.family": "serif",
7456
# Use 10pt font in plots, to match 10pt font in document
7557
"axes.labelsize": (10 / 8) * font_size,
@@ -84,7 +66,6 @@ def load_plt_sns_configs(font_size=8, **rcparams):
8466
"figure.subplot.bottom": 0.12,
8567
"figure.subplot.top": 0.944,
8668
"figure.subplot.wspace": 0.071,
87-
"figure.subplot.hspace": 0.2,
8869
}
8970
plt.rcParams.update(base_style)
9071

@@ -104,13 +85,13 @@ def load_plt_sns_configs(font_size=8, **rcparams):
10485
plt.rcParams.update(rcparams)
10586

10687

107-
def val_to_color(col, cmap="RdYlBu_r"):
88+
def feature_to_color(col, cmap="RdYlBu_r"):
10889
"""
10990
Converts a column of values to hex-type colors.
11091
11192
Parameters
11293
----------
113-
col : array-like of shape (n_samples,)
94+
col : {list, array-like} of shape (n_samples,)
11495
Values to convert to hex-type color code
11596
11697
cmap : str or `~matplotlib.colors.Colormap`
@@ -124,6 +105,9 @@ def val_to_color(col, cmap="RdYlBu_r"):
124105
colors = _optional_import("matplotlib.colors")
125106
cm = _optional_import("matplotlib.cm")
126107

108+
if type(col) == list:
109+
col = np.array(col)
110+
127111
norm = colors.Normalize(vmin=col.min(), vmax=col.max(), clip=True)
128112
mapper = cm.ScalarMappable(norm=norm, cmap=cmap)
129113
rgba = mapper.to_rgba(col)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import numpy as np
2+
import matplotlib.pyplot as plt
3+
4+
from .._visualization import set_matplotlib_style, feature_to_color
5+
6+
7+
def test_set_matplotlib_style():
8+
default_params = dict(plt.rcParams)
9+
set_matplotlib_style(12)
10+
new_params = dict(plt.rcParams)
11+
changed_params = [
12+
key for key in new_params.keys() if new_params[key] != default_params[key]
13+
]
14+
assert len(changed_params) > 1
15+
16+
17+
def test_feature_to_color():
18+
colors = feature_to_color(np.array([1, 2, 3, 4, 5]))
19+
colors2 = feature_to_color([1, 2, 3, 4, 5])
20+
assert (colors == colors2).all()
21+
assert colors.size == np.unique(colors).size

0 commit comments

Comments
 (0)