diff --git a/datastock/_class1.py b/datastock/_class1.py index d4003c9..9e61983 100644 --- a/datastock/_class1.py +++ b/datastock/_class1.py @@ -21,6 +21,7 @@ from . import _class1_binning from . import _class1_interpolate from . import _class1_uniformize +from . import _class1_color_touch as _color_touch from . import _export_dataframe from . import _find_plateau @@ -923,6 +924,32 @@ def interpolate( inplace=inplace, ) + # --------------------- + # color touch array + # --------------------- + + def get_color_touch( + self, + data=None, + dcolor=None, + # options + color_default=None, + vmin=None, + vmax=None, + log=None, + ): + + return _color_touch.main( + coll=self, + data=data, + dcolor=dcolor, + # options + color_default=color_default, + vmin=vmin, + vmax=vmax, + log=log, + ) + # --------------------- # Methods computing correlations # --------------------- diff --git a/datastock/_class1_color_touch.py b/datastock/_class1_color_touch.py new file mode 100644 index 0000000..83ac4b3 --- /dev/null +++ b/datastock/_class1_color_touch.py @@ -0,0 +1,269 @@ +# -*- coding: utf-8 -*- +""" +Created on Fri Feb 28 08:53:00 2025 + +@author: dvezinet +""" + + +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.colors as mcolors +import datastock as ds + + +# ############################################################### +# ############################################################### +# Main +# ############################################################### + + +def main( + coll=None, + data=None, + dcolor=None, + # options + color_default=None, + vmin=None, + vmax=None, + log=None, +): + + # ------------------ + # check inputs + # ------------------ + + data, dcolor, color_default, vmin, vmax, log = _check( + coll=coll, + data=data, + dcolor=dcolor, + color_default=color_default, + vmin=vmin, + vmax=vmax, + log=log, + ) + + # ------------------ + # initialize + # ------------------ + + shape = data.shape + (4,) + color = np.zeros(shape, dtype=float) + + # ------------------ + # compute - alpha + # ------------------ + + if log is True: + vmin = np.log10(vmin) + vmax = np.log10(vmax) + + alpha = (np.log10(data) - vmin) / (vmax - vmin) + + else: + alpha = (data - vmin) / (vmax - vmin) + + # ------------------ + # compute - colors + # ------------------ + + for k0, v0 in dcolor.items(): + + sli = (v0['ind'], slice(0, 3)) + color[sli] = v0['color'] + + sli = tuple([slice(None) for ii in range(data.ndim)] + [-1]) + color[sli] = alpha + + # ------------------ + # output + # ------------------ + + lcol = set([v0['color'] for v0 in dcolor.values()]) + dcolor = { + 'color': color, + 'meaning': { + kc: [k0 for k0, v0 in dcolor.items() if v0['color'] == kc] + for kc in lcol + }, + } + + return dcolor + + +# ############################################################### +# ############################################################### +# check +# ############################################################### + + +def _check( + coll=None, + data=None, + dcolor=None, + # options + color_default=None, + vmin=None, + vmax=None, + log=None, +): + + # ------------------ + # data + # ------------------ + + lc = [ + isinstance(data, np.ndarray), + isinstance(data, str) and data in coll.ddata.keys(), + ] + if lc[0]: + pass + elif lc[1]: + data = coll.ddata[data]['data'] + else: + msg = ( + "Arg data must be a np.ndarray or a key to an existing data!\n" + f"Provided: {data}\n" + ) + raise Exception(msg) + + + # ------------------ + # dcolor + # ------------------ + + # -------------------- + # dcolor format check + + c0 = ( + isinstance(dcolor, dict) + and all([ + isinstance(k0, str) + and isinstance(v0, dict) + and sorted(v0.keys()) == ['color', 'ind'] + for k0, v0 in dcolor.items() + ]) + ) + if not c0: + msg = ( + "Arg dcolor must be a dict of sub-dicts of shape:\n" + "\t- 'key0': {'ind': ..., 'color': ...}\n" + "\t- ...\n" + "\t- 'keyN': {'ind': ..., 'color': ...}\n" + f"Provided:\n{dcolor}\n" + ) + raise Exception(msg) + + # -------------------- + # ind and color checks + + dfail = {} + shape = data.shape + for k0, v0 in dcolor.items(): + + c0 = ( + isinstance(v0['ind'], np.ndarray) + and v0['ind'].shape == data.shape + and v0['ind'].dtype == bool + ) + if not c0: + msg = f"'ind' must be a {shape} bool array, not {v0['ind']}" + dfail[k0] = (msg,) + + if not mcolors.is_color_like(v0['color']): + msg = f"'color' must be color-like, not {v0['color']}" + if k0 in dfail: + dfail[k0] = dfail[k0] + (msg,) + else: + dfail[k0] = (msg,) + + # raise exception + if len(dfail) > 0: + lmax = np.max([len(f"\t- {k0}: ") for k0 in dfail.keys()]) + lstr = [ + f"\t- {k0}:\n".ljust(lmax) + '\n'.join([ + "".ljust(lmax+4) + f"\t- {v1}".rjust(lmax) + for ii, v1 in enumerate(v0) + ]) + for k0, v0 in dfail.items() + ] + msg = ( + "Arg dcolor, the following keys have incorrect keys / values:\n" + + "\n".join(lstr) + ) + raise Exception(msg) + + # ---------------------- + # format colors to rgb + + dcol = {} + for k0, v0 in dcolor.items(): + if np.any(v0['ind']): + dcol[k0] = { + 'ind': v0['ind'], + 'color': mcolors.to_rgb(v0['color']), + } + + # ------------------ + # color_default + # ------------------ + + if color_default is None: + color_default = 'k' + if not mcolors.is_color_like(color_default): + msg = ( + "Arg color_default must be color-like!\n" + f"Provided: {color_default}\n" + ) + raise Exception(msg) + + color_default = mcolors.to_rgb(color_default) + + # ------------------ + # vmin, vmax + # ------------------ + + vmin0 = np.nanmin(data) + vmax0 = np.nanmax(data) + + # vmin + if vmin is None: + vmin = vmin0 + c0 = (np.isscalar(vmin) and np.isfinite(vmin) and vmin < vmax0) + if not c0: + msg = ( + f"Arg vmin must be a finite scalar below max ({vmax0})\n" + f"Provided: {vmin}\n" + ) + raise Exception(msg) + + # vmax + if vmax is None: + vmax = vmax0 + c0 = (np.isscalar(vmax) and np.isfinite(vmax) and vmax > vmin0) + if not c0: + msg = ( + f"Arg vmax must be a finite scalar above min ({vmin0})\n" + f"Provided: {vmax}\n" + ) + raise Exception(msg) + + # ordering + if vmin >= vmax: + msg = ( + "Arg vmin must be below vmax!\n" + f"Provided:\n\t- vmin = {vmin}\n\t- vmax = {vmax}\n" + ) + raise Exception(msg) + + # ------------------ + # log + # ------------------ + + log = ds._generic_check._check_var( + log, 'log', + types=bool, + default=False, + ) + + return data, dcol, color_default, vmin, vmax, log \ No newline at end of file diff --git a/datastock/tests/test_01_DataStock.py b/datastock/tests/test_01_DataStock.py index 374bab1..59da29a 100644 --- a/datastock/tests/test_01_DataStock.py +++ b/datastock/tests/test_01_DataStock.py @@ -306,11 +306,26 @@ def test05_show(self): self.st.show_data() self.st.show_obj() + # ------------------------ + # dcolor + # ------------------------ + + def test06_get_dcolor_touch(self): + xx = np.arange(50) + aa = np.exp(-(xx[:, None]-25)**2/10**2 - (xx[None, :]-25)**2/10**2) + ind = (aa>0.3) & (np.arange(50)[None, :] > 25) + dcolor = self.st.get_color_touch( + aa, + dcolor={'foo': {'ind': ind, 'color': 'r'}} + ) + assert dcolor['color'].shape == aa.shape + (4,) + assert dcolor['meaning'][(1.0, 0.0, 0.0)] == ['foo'] + # ------------------------ # Interpolate # ------------------------ - def test06_get_ref_vector(self): + def test07_get_ref_vector(self): ( hasref, hasvector, ref, key_vector, @@ -325,13 +340,13 @@ def test06_get_ref_vector(self): assert values.size == dind['ind'].size == 4 assert dind['indr'].shape == (2, 4) - def test07_get_ref_vector_common(self): + def test08_get_ref_vector_common(self): hasref, ref, key, val, dout = self.st.get_ref_vector_common( keys=['t0', 'prof0', 'prof1', 't3'], dim='time', ) - def test08_domain_ref(self): + def test09_domain_ref(self): domain = { 'nx': [1.5, 2], @@ -347,7 +362,7 @@ def test08_domain_ref(self): lk = list(domain.keys()) assert all([isinstance(dout[k0]['ind'], np.ndarray) for k0 in lk]) - def test09_binning(self): + def test10_binning(self): bins = np.linspace(1, 5, 8) lk = [ @@ -399,7 +414,7 @@ def test09_binning(self): ) raise Exception(msg) - def test10_interpolate(self): + def test11_interpolate(self): lk = ['y', 'y', 'prof0', 'prof0', 'prof0', '3d'] lref = [None, 'nx', 't0', ['nt0', 'nx'], ['t0', 'x'], ['t0', 'x']] @@ -443,7 +458,7 @@ def test10_interpolate(self): msg = str(dout[kk]['data'].shape, shape, kk, rr) raise Exception(msg) - def test11_interpolate_common_refs(self): + def test12_interpolate_common_refs(self): lk = ['3d', '3d', '3d'] lref = ['t0', ['nt0', 'nx'], ['nx']] lrefc = ['nc', 'nc', 'nt0'] @@ -519,17 +534,17 @@ def test11_interpolate_common_refs(self): # Plotting # ------------------------ - def test12_plot_as_array_1d(self): + def test13_plot_as_array_1d(self): dax = self.st.plot_as_array(key='t0') plt.close('all') del dax - def test13_plot_as_array_2d(self): + def test14_plot_as_array_2d(self): dax = self.st.plot_as_array(key='prof0') plt.close('all') del dax - def test14_plot_as_array_2d_log(self): + def test15_plot_as_array_2d_log(self): dax = self.st.plot_as_array( key='pec', keyX='ne', keyY='Te', dscale={'data': 'log'}, @@ -537,17 +552,17 @@ def test14_plot_as_array_2d_log(self): plt.close('all') del dax - def test15_plot_as_array_3d(self): + def test16_plot_as_array_3d(self): dax = self.st.plot_as_array(key='3d', dvminmax={'keyX': {'min': 0}}) plt.close('all') del dax - def test16_plot_as_array_3d_ZNonMonot(self): + def test17_plot_as_array_3d_ZNonMonot(self): dax = self.st.plot_as_array(key='3d', keyZ='y') plt.close('all') del dax - def test17_plot_as_array_4d(self): + def test18_plot_as_array_4d(self): dax = self.st.plot_as_array(key='4d', dscale={'keyU': 'linear'}) plt.close('all') del dax @@ -557,7 +572,7 @@ def test17_plot_as_array_4d(self): # plt.close('all') # del dax - def test19_plot_as_profile1d(self): + def test20_plot_as_profile1d(self): dax = self.st.plot_as_profile1d( key='prof0', key_time='t0', @@ -591,7 +606,7 @@ def test19_plot_as_profile1d(self): # File handling # ------------------------ - def test21_copy_equal(self): + def test22_copy_equal(self): st2 = self.st.copy() assert st2 is not self.st @@ -599,15 +614,15 @@ def test21_copy_equal(self): if msg is not True: raise Exception(msg) - def test22_get_nbytes(self): + def test23_get_nbytes(self): nb, dnb = self.st.get_nbytes() - def test23_save_pfe(self, verb=False): + def test24_save_pfe(self, verb=False): pfe = os.path.join(_PATH_OUTPUT, 'testsave.npz') self.st.save(pfe=pfe, return_pfe=False) os.remove(pfe) - def test24_saveload(self, verb=False): + def test25_saveload(self, verb=False): pfe = self.st.save(path=_PATH_OUTPUT, verb=verb, return_pfe=True) st2 = load(pfe, verb=verb) # Just to check the loaded version works fine @@ -616,7 +631,7 @@ def test24_saveload(self, verb=False): raise Exception(msg) os.remove(pfe) - def test25_saveload_coll(self, verb=False): + def test26_saveload_coll(self, verb=False): pfe = self.st.save(path=_PATH_OUTPUT, verb=verb, return_pfe=True) st = DataStock() st2 = load(pfe, coll=st, verb=verb)