Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions datastock/_class1.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from . import _class1_binning
from . import _class1_interpolate
from . import _class1_uniformize
from . import _class1_color_touch as _color_touch
from . import _export_dataframe
from . import _find_plateau

Expand Down Expand Up @@ -923,6 +924,32 @@ def interpolate(
inplace=inplace,
)

# ---------------------
# color touch array
# ---------------------

def get_color_touch(
self,
data=None,
dcolor=None,
# options
color_default=None,
vmin=None,
vmax=None,
log=None,
):

return _color_touch.main(
coll=self,
data=data,
dcolor=dcolor,
# options
color_default=color_default,
vmin=vmin,
vmax=vmax,
log=log,
)

# ---------------------
# Methods computing correlations
# ---------------------
Expand Down
269 changes: 269 additions & 0 deletions datastock/_class1_color_touch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
# -*- coding: utf-8 -*-
"""
Created on Fri Feb 28 08:53:00 2025

@author: dvezinet
"""


import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import datastock as ds


# ###############################################################
# ###############################################################
# Main
# ###############################################################


def main(
coll=None,
data=None,
dcolor=None,
# options
color_default=None,
vmin=None,
vmax=None,
log=None,
):

# ------------------
# check inputs
# ------------------

data, dcolor, color_default, vmin, vmax, log = _check(
coll=coll,
data=data,
dcolor=dcolor,
color_default=color_default,
vmin=vmin,
vmax=vmax,
log=log,
)

# ------------------
# initialize
# ------------------

shape = data.shape + (4,)
color = np.zeros(shape, dtype=float)

# ------------------
# compute - alpha
# ------------------

if log is True:
vmin = np.log10(vmin)
vmax = np.log10(vmax)

alpha = (np.log10(data) - vmin) / (vmax - vmin)

else:
alpha = (data - vmin) / (vmax - vmin)

# ------------------
# compute - colors
# ------------------

for k0, v0 in dcolor.items():

sli = (v0['ind'], slice(0, 3))
color[sli] = v0['color']

sli = tuple([slice(None) for ii in range(data.ndim)] + [-1])
color[sli] = alpha

# ------------------
# output
# ------------------

lcol = set([v0['color'] for v0 in dcolor.values()])
dcolor = {
'color': color,
'meaning': {
kc: [k0 for k0, v0 in dcolor.items() if v0['color'] == kc]
for kc in lcol
},
}

return dcolor


# ###############################################################
# ###############################################################
# check
# ###############################################################


def _check(
coll=None,
data=None,
dcolor=None,
# options
color_default=None,
vmin=None,
vmax=None,
log=None,
):

# ------------------
# data
# ------------------

lc = [
isinstance(data, np.ndarray),
isinstance(data, str) and data in coll.ddata.keys(),
]
if lc[0]:
pass
elif lc[1]:
data = coll.ddata[data]['data']
else:
msg = (
"Arg data must be a np.ndarray or a key to an existing data!\n"
f"Provided: {data}\n"
)
raise Exception(msg)


# ------------------
# dcolor
# ------------------

# --------------------
# dcolor format check

c0 = (
isinstance(dcolor, dict)
and all([
isinstance(k0, str)
and isinstance(v0, dict)
and sorted(v0.keys()) == ['color', 'ind']
for k0, v0 in dcolor.items()
])
)
if not c0:
msg = (
"Arg dcolor must be a dict of sub-dicts of shape:\n"
"\t- 'key0': {'ind': ..., 'color': ...}\n"
"\t- ...\n"
"\t- 'keyN': {'ind': ..., 'color': ...}\n"
f"Provided:\n{dcolor}\n"
)
raise Exception(msg)

# --------------------
# ind and color checks

dfail = {}
shape = data.shape
for k0, v0 in dcolor.items():

c0 = (
isinstance(v0['ind'], np.ndarray)
and v0['ind'].shape == data.shape
and v0['ind'].dtype == bool
)
if not c0:
msg = f"'ind' must be a {shape} bool array, not {v0['ind']}"
dfail[k0] = (msg,)

if not mcolors.is_color_like(v0['color']):
msg = f"'color' must be color-like, not {v0['color']}"
if k0 in dfail:
dfail[k0] = dfail[k0] + (msg,)
else:
dfail[k0] = (msg,)

# raise exception
if len(dfail) > 0:
lmax = np.max([len(f"\t- {k0}: ") for k0 in dfail.keys()])
lstr = [
f"\t- {k0}:\n".ljust(lmax) + '\n'.join([
"".ljust(lmax+4) + f"\t- {v1}".rjust(lmax)
for ii, v1 in enumerate(v0)
])
for k0, v0 in dfail.items()
]
msg = (
"Arg dcolor, the following keys have incorrect keys / values:\n"
+ "\n".join(lstr)
)
raise Exception(msg)

# ----------------------
# format colors to rgb

dcol = {}
for k0, v0 in dcolor.items():
if np.any(v0['ind']):
dcol[k0] = {
'ind': v0['ind'],
'color': mcolors.to_rgb(v0['color']),
}

# ------------------
# color_default
# ------------------

if color_default is None:
color_default = 'k'
if not mcolors.is_color_like(color_default):
msg = (
"Arg color_default must be color-like!\n"
f"Provided: {color_default}\n"
)
raise Exception(msg)

color_default = mcolors.to_rgb(color_default)

# ------------------
# vmin, vmax
# ------------------

vmin0 = np.nanmin(data)
vmax0 = np.nanmax(data)

# vmin
if vmin is None:
vmin = vmin0
c0 = (np.isscalar(vmin) and np.isfinite(vmin) and vmin < vmax0)
if not c0:
msg = (
f"Arg vmin must be a finite scalar below max ({vmax0})\n"
f"Provided: {vmin}\n"
)
raise Exception(msg)

# vmax
if vmax is None:
vmax = vmax0
c0 = (np.isscalar(vmax) and np.isfinite(vmax) and vmax > vmin0)
if not c0:
msg = (
f"Arg vmax must be a finite scalar above min ({vmin0})\n"
f"Provided: {vmax}\n"
)
raise Exception(msg)

# ordering
if vmin >= vmax:
msg = (
"Arg vmin must be below vmax!\n"
f"Provided:\n\t- vmin = {vmin}\n\t- vmax = {vmax}\n"
)
raise Exception(msg)

# ------------------
# log
# ------------------

log = ds._generic_check._check_var(
log, 'log',
types=bool,
default=False,
)

return data, dcol, color_default, vmin, vmax, log
2 changes: 1 addition & 1 deletion datastock/_class1_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -1233,4 +1233,4 @@ def _extract_select(
# lkey=[idq2dR],
# return_all=True,
# )
# return out
# return out
22 changes: 17 additions & 5 deletions datastock/_class1_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1574,7 +1574,8 @@ def _xunique(dout=None, domain=None):
}

# Number of Nones expected
nNone = 1 + len(domain)
ndom = 0 if domain is None else len(domain)
nNone = 1 + ndom

# check
dwrong = {k0: v0 for k0, v0 in dind.items() if len(v0) != nNone}
Expand All @@ -1583,7 +1584,7 @@ def _xunique(dout=None, domain=None):
f"\t- {k0}: {dout[k0]['ref']} => {v0}" for k0, v0 in dwrong.items()
]
msg = (
"Interpolate unique pt => ref should have nNone = 1 + {len(domain)}:\n"
"Interpolate unique pt => ref should have nNone = 1 + {ndom}:\n"
+ "\n".join(lstr)
)
raise Exception(msg)
Expand Down Expand Up @@ -1626,7 +1627,12 @@ def _store(
ldata = list(set(itt.chain.from_iterable([
v0['ref'] for v0 in dout.values()
])))
coll2 = coll.extract(keys=ldata, vectors=True)

coll2 = coll.extract(
keys=ldata,
inc_vectors=True,
return_keys=False,
)

# -------------
# store_keys
Expand All @@ -1644,7 +1650,13 @@ def _store(
excluded=lout,
)

assert len(store_keys) == len(dout)
if len(store_keys) != len(dout):
msg = (
"Nb of store_keys != nb of keys in dout!\n"
f"\t- store_keys:\n{store_keys}\n "
f"\t- dout.keys():\n{sorted(dout.keys())}\n "
)
raise Exception(msg)

# ---------
# add data
Expand All @@ -1658,4 +1670,4 @@ def _store(
units=v0['units'],
)

return coll2
return coll2
Loading
Loading