diff --git a/datastock/__init__.py b/datastock/__init__.py index f3a72d8..7d99b7d 100644 --- a/datastock/__init__.py +++ b/datastock/__init__.py @@ -1,10 +1,8 @@ - - from .version import __version__ from . import _generic_check from ._generic_utils_plot import * -from ._class import DataStock +from ._class04_Plots import Plots as Collection from ._saveload import load, get_files from ._direct_calls import * -from . import tests \ No newline at end of file +from . import tests diff --git a/datastock/_class.py b/datastock/_class.py deleted file mode 100644 index 77a9235..0000000 --- a/datastock/_class.py +++ /dev/null @@ -1,3 +0,0 @@ - - -from ._class3 import DataStock3 as DataStock diff --git a/datastock/_class0.py b/datastock/_class00.py similarity index 100% rename from datastock/_class0.py rename to datastock/_class00.py diff --git a/datastock/_class1.py b/datastock/_class01.py similarity index 84% rename from datastock/_class1.py rename to datastock/_class01.py index 27e2304..39d7ee0 100644 --- a/datastock/_class1.py +++ b/datastock/_class01.py @@ -1,26 +1,21 @@ # -*- coding: utf-8 -*- -# Built-in -import copy - - # Common import numpy as np import astropy.units as asunits # library-specific +from ._class00 import DataStock0 as Previous from . import _generic_check from . import _generic_utils -from . import _class1_check -from . import _class1_show -from ._class0 import * -from . import _class1_compute -from . import _class1_domain -from . import _class1_binning -from . import _class1_interpolate -from . import _class1_uniformize +from . import _class01_check as _check +from . import _class01_show as _show +from . import _class01_compute as _compute +from . import _class01_domain as _domain +from . import _class01_interpolate as _interpolate +from . import _class01_uniformize as _uniformize from . import _export_dataframe from . 
import _find_plateau @@ -32,7 +27,7 @@ ############################################# -class DataStock1(DataStock0): +class DataStock1(Previous): """ A generic class for handling data Provides methods for: @@ -117,7 +112,7 @@ def update( # Check consistency ( self._dref, self._ddata, self._dobj, self.__dlinks, - ) = _class1_check._consistency( + ) = _check._consistency( dobj=dobj, dobj0=self._dobj, ddata=ddata, ddata0=self._ddata, dref=dref, dref0=self._dref, @@ -173,7 +168,7 @@ def remove_ref(self, key=None, propagate=None): """ Remove a ref (or list of refs) and all associated data """ ( self._dref, self._ddata, self._dobj, self.__dlinks, - ) = _class1_check._remove_ref( + ) = _check._remove_ref( key=key, dref0=self._dref, ddata0=self._ddata, dobj0=self._dobj, @@ -189,7 +184,7 @@ def remove_data(self, key=None, propagate=True): """ Remove a data (or list of data) """ ( self._dref, self._ddata, self._dobj, self.__dlinks, - ) = _class1_check._remove_data( + ) = _check._remove_data( key=key, dref0=self._dref, ddata0=self._ddata, dobj0=self._dobj, @@ -205,7 +200,7 @@ def remove_obj(self, key=None, which=None, propagate=True): """ Remove a data (or list of data) """ ( self._dref, self._ddata, self._dobj, self.__dlinks, - ) = _class1_check._remove_obj( + ) = _check._remove_obj( key=key, which=which, propagate=propagate, @@ -247,7 +242,7 @@ def remove_all(self, excluded=None): def __check_which(self, which=None, return_dict=None): """ Check which in ['data'] + list(self._dobj.keys() """ - return _class1_check._check_which( + return _check._check_which( dref=self._dref, ddata=self._ddata, dobj=self._dobj, @@ -266,7 +261,7 @@ def get_lparam(self, which=None, for_show=None): which, dd = self.__check_which(which, return_dict=True) if which in ['ref', 'data']: for_show = False - return _class1_show._get_lparam(dd=dd, for_show=for_show) + return _show._get_lparam(dd=dd, for_show=for_show) def get_param( self, @@ -291,7 +286,7 @@ def get_param( """ which, dd = self.__check_which(which, return_dict=True) - return _class1_check._get_param( + return _check._get_param( dd=dd, dd_name=which, param=param, key=key, ind=ind, returnas=returnas, ) @@ -322,7 +317,7 @@ def set_param( """ which, dd = self.__check_which(which, return_dict=True) - param = _class1_check._set_param( + param = _check._set_param( dd=dd, dd_name=which, param=param, value=value, ind=ind, key=key, distribute=distribute, @@ -340,7 +335,7 @@ def add_param( ): """ Add a parameter, optionnally also set its value """ which, dd = self.__check_which(which, return_dict=True) - param = _class1_check._add_param( + param = _check._add_param( dd=dd, dd_name=which, param=param, @@ -358,7 +353,7 @@ def remove_param( ): """ Remove a parameter, none by default, all if param = 'all' """ which, dd = self.__check_which(which, return_dict=True) - _class1_check._remove_param( + _check._remove_param( dd=dd, dd_name=which, param=param, @@ -427,7 +422,7 @@ def propagate_indices_per_ref( - 'index': set matching indices (default) - param: set matching monotonous quantities depending on ref """ - _class1_compute.propagate_indices_per_ref( + _compute.propagate_indices_per_ref( ref=ref, lref=lref, ldata=ldata, @@ -470,7 +465,7 @@ def extract( """ - return _class1_compute._extract_instance( + return _compute._extract_instance( self, keys=keys, # optional includes @@ -529,7 +524,7 @@ def select(self, which=None, log=None, returnas=None, **kwdargs): """ which, dd = self.__check_which(which, return_dict=True) - return _class1_check._select( + return _check._select( dd=dd, 
dd_name=which, log=log, returnas=returnas, **kwdargs, @@ -544,7 +539,7 @@ def _ind_tofrom_key( ): """ Return ind from key or key from ind for all data """ which, dd = self.__check_which(which, return_dict=True) - return _class1_check._ind_tofrom_key( + return _check._ind_tofrom_key( dd=dd, dd_name=which, ind=ind, key=key, returnas=returnas, ) @@ -556,7 +551,14 @@ def _get_sort_index(self, which=None, param=None): return if param == 'key': - ind = np.argsort(list(dd.keys())) + if which == 'ref': + lk = list(self.dref.keys()) + elif which == 'data': + lk = list(self.ddata.keys()) + else: + lk = list(self.dobj.get(which, {}).keys()) + ind = np.argsort(lk) + elif isinstance(param, str): ind = np.argsort( self.get_param(param, which=which, returnas=np.ndarray)[param] @@ -564,6 +566,7 @@ def _get_sort_index(self, which=None, param=None): else: msg = "Arg param must be a valid str\n Provided: {}".format(param) raise Exception(msg) + return ind def sortby(self, param=None, order=None, which=None): @@ -640,7 +643,9 @@ def get_ref_vector( >>> st.add_data(key='t0', data=t0) >>> st.add_data(key='x', data=x) >>> st.add_data(key='xt', data=xt) - >>> hasref, hasvect, ref, key_vect, dind = st.get_ref_vector(key='xt', ref='nt', values=[2, 3, 3.1, 5]) + >>> hasref, hasvect, ref, key_vect, dind = st.get_ref_vector( + >>> key='xt', ref='nt', values=[2, 3, 3.1, 5], + >>> ) In the above example: - hasref = True: 'xt' has 'nt' has ref @@ -651,13 +656,13 @@ def get_ref_vector( 'key': [2, 3, 3.1, 5], # the desired time points 'ind': [2, 3, 3, 5], # the indices of t in t0 'indu': [2, 3, 5] # the unique indices of t in t0 - 'indr': (3, 4), # bool array showing, for each indu, matching ind + 'indr': (3, 4), # bool array with ind for each indu 'indok': [True, False, ...] } """ - return _class1_uniformize.get_ref_vector( + return _uniformize.get_ref_vector( # ressources ddata=self._ddata, dref=self._dref, @@ -712,7 +717,7 @@ def get_ref_vector_common( """ - return _class1_uniformize.get_ref_vector_common( + return _uniformize.get_ref_vector_common( # ressources ddata=self._ddata, dref=self._dref, @@ -750,7 +755,7 @@ def uniformize( returnas=None, ): - return _class1_uniformize.uniformize( + return _uniformize.uniformize( coll=self, keys=keys, refs=refs, @@ -770,100 +775,7 @@ def get_domain_ref( """ Return a dict of index of valid steps based on desired domain """ - return _class1_domain.domain_ref(coll=self, domain=domain) - - # --------------------- - # Binning - # --------------------- - - def binning( - self, - data=None, - data_units=None, - axis=None, - # binning - bins0=None, - bins1=None, - bin_data0=None, - bin_data1=None, - bin_units0=None, - # kind of binning - integrate=None, - statistic=None, - # options - safety_ratio=None, - dref_vector=None, - verb=None, - returnas=None, - # storing - store=None, - store_keys=None, - ): - """ Return the binned data - - data: the data on which to apply binning, can be - - a list of np.ndarray to be binned - (any dimension as long as they all have the same) - - a list of keys to ddata items sharing the same refs - - data_units: str only necessary if data is a list of arrays - - axis: int or array of int indices - the axis of data along which to bin - data will be flattened along all those axis priori to binning - If None, assumes bin_data is not variable and uses all its axis - - bins0: the bins (centers), can be - - a 1d vector of monotonous bins - - a int, used to compute a bins vector from max(data), min(data) - - bin_data0: the data used to compute binning indices, can 
be: - - a str, key to a ddata item - - a np.ndarray - _ a list of any of the above if each data has different size along axis - - bin_units: str - only used if integrate = True and bin_data is a np.ndarray - - integrate: bool - flag indicating whether binning is used for integration - Implies that: - Only usable for 1d binning (axis has to be a single index) - data is multiplied by the underlying bin_data0 step prior to binning - - statistic: str - the statistic kwd feed to scipy.stats.binned_statistic() - automatically set to 'sum' if integrate = True - - store: bool - If True, will sotre the result in ddata - Only possible if all (data, bin_data and bin) are provided as keys - - """ - - return _class1_binning.binning( - coll=self, - data=data, - data_units=data_units, - axis=axis, - # binning - bins0=bins0, - bins1=bins1, - bin_data0=bin_data0, - bin_data1=bin_data1, - bin_units0=bin_units0, - # kind of binning - integrate=integrate, - statistic=statistic, - # options - safety_ratio=safety_ratio, - dref_vector=dref_vector, - verb=verb, - returnas=returnas, - # storing - store=store, - store_keys=store_keys, - ) + return _domain.domain_ref(coll=self, domain=domain) # --------------------- # Interpolation @@ -897,7 +809,7 @@ def interpolate( """ Interpolate keys in desired dimension """ - return _class1_interpolate.interpolate( + return _interpolate.interpolate( coll=self, # interpolation base keys=keys, @@ -935,7 +847,7 @@ def compute_correlations( verb=None, returnas=None, ): - return _class1_compute.correlations( + return _compute.correlations( data=data, ref=ref, correlations=correlations, @@ -963,7 +875,7 @@ def show( returnas=False, ): """ Summary description of the object content """ - return _class1_show.main( + return _show.main( coll=self, show_which=show_which, show=show, @@ -978,7 +890,7 @@ def show( ) def _get_show_obj(self, which=None): - return _class1_show._show_obj_def + return _show._show_obj_def def show_data(self): self.show(show_which=['ref', 'data']) @@ -1003,7 +915,7 @@ def show_details( returnas=False, ): """ Summary description of the object content """ - return _class1_show.main_details( + return _show.main_details( coll=self, which=which, key=key, @@ -1052,4 +964,4 @@ def show_links(self): __all__ = [ sorted([k0 for k0 in locals() if k0.startswith('DataStock')])[-1] -] \ No newline at end of file +] diff --git a/datastock/_class1_check.py b/datastock/_class01_check.py similarity index 100% rename from datastock/_class1_check.py rename to datastock/_class01_check.py diff --git a/datastock/_class1_compute.py b/datastock/_class01_compute.py similarity index 99% rename from datastock/_class1_compute.py rename to datastock/_class01_compute.py index a137eea..a204115 100644 --- a/datastock/_class1_compute.py +++ b/datastock/_class01_compute.py @@ -1233,4 +1233,4 @@ def _extract_select( # lkey=[idq2dR], # return_all=True, # ) - # return out \ No newline at end of file + # return out diff --git a/datastock/_class1_domain.py b/datastock/_class01_domain.py similarity index 99% rename from datastock/_class1_domain.py rename to datastock/_class01_domain.py index d483833..de3b74b 100644 --- a/datastock/_class1_domain.py +++ b/datastock/_class01_domain.py @@ -244,4 +244,4 @@ def _set_ind_from_domain( ind = ind_in & (~ind_out) - return ind \ No newline at end of file + return ind diff --git a/datastock/_class1_interpolate.py b/datastock/_class01_interpolate.py similarity index 100% rename from datastock/_class1_interpolate.py rename to datastock/_class01_interpolate.py diff 
--git a/datastock/_class1_show.py b/datastock/_class01_show.py similarity index 100% rename from datastock/_class1_show.py rename to datastock/_class01_show.py diff --git a/datastock/_class1_uniformize.py b/datastock/_class01_uniformize.py similarity index 100% rename from datastock/_class1_uniformize.py rename to datastock/_class01_uniformize.py diff --git a/datastock/_class2.py b/datastock/_class02.py similarity index 91% rename from datastock/_class2.py rename to datastock/_class02.py index 8d08e05..ce2fb98 100644 --- a/datastock/_class2.py +++ b/datastock/_class02.py @@ -13,9 +13,8 @@ from . import _generic_check -from ._class1 import * -from . import _class2_interactivity -from . import _class1_compute +from ._class01 import DataStock1 as Previous +from . import _class02_interactivity as _interactivity # ################################################################# @@ -24,7 +23,7 @@ # ################################################################# -class DataStock2(DataStock1): +class DataStock2(Previous): """ Handles matplotlib interactivity """ _LPAXES = ['axes', 'type'] @@ -150,7 +149,7 @@ def add_mobile( else: msg = ( f"In dmobile['{key}']:\n" - "Nb. of different dtypes must match nb of different data!\n" + "Nb. of diff. dtypes must match nb of diff. data!\n" f"\t- dtype: {dtype}\n" f"\t- data: {data}\n" ) @@ -193,15 +192,15 @@ def add_axes( # check refx, refy # if refx is None and refy is None: - # msg = f"Please provide at least refx or refy for axes {key}!" - # raise Exception(msg) + # msg = f"Please provide at least refx or refy for axes {key}!" + # raise Exception(msg) if isinstance(refx, str): refx = [refx] if isinstance(refy, str): refy = [refy] - c0 =( + c0 = ( isinstance(refx, list) and all([rr in self._dref.keys() for rr in refx]) ) @@ -209,7 +208,7 @@ def add_axes( msg = "Arg refx must be a list of valid ref keys!" 
raise Exception(msg) - c0 =( + c0 = ( isinstance(refy, list) and all([rr in self._dref.keys() for rr in refy]) ) @@ -311,7 +310,7 @@ def dinteractivity(self): # ------------------ def show_commands(self, verb=None, returnas=None): - return _class2_interactivity.show_commands( + return _interactivity.show_commands( verb=verb, returnas=returnas, ) @@ -363,7 +362,7 @@ def setup_interactivity( # ---------- # Check dgroup - dgroup, newgroup = _class2_interactivity._setup_dgroup( + dgroup, newgroup = _interactivity._setup_dgroup( dgroup=dgroup, dobj0=self._dobj, dref0=self._dref, @@ -372,7 +371,7 @@ def setup_interactivity( # ---------- # Check increment dict - dinc, newinc = _class2_interactivity._setup_dinc( + dinc, newinc = _interactivity._setup_dinc( dinc=dinc, lparam_ref=self.get_lparam(which='ref'), dref0=self._dref, @@ -381,7 +380,7 @@ def setup_interactivity( # ---------------------------------------------------------- # make sure all refs are known and are associated to a group - drefgroup = _class2_interactivity._setup_drefgroup( + drefgroup = _interactivity._setup_drefgroup( dref0=self._dref, dgroup=dgroup, ) @@ -389,8 +388,9 @@ def setup_interactivity( # add indices to ref for k0, v0 in self._dref.items(): if drefgroup[k0] is not None: + zeros = np.zeros((dgroup[drefgroup[k0]]['nmax'],), dtype=int) self.add_indices_per_ref( - indices=np.zeros((dgroup[drefgroup[k0]]['nmax'],), dtype=int), + indices=zeros, ref=k0, distribute=False, ) @@ -451,7 +451,7 @@ def setup_interactivity( # -------------------------- # update mobile with group, group_vis and func - _class2_interactivity._setup_mobile( + _interactivity._setup_mobile( dmobile=self._dobj['mobile'], dref=self._dref, ddata=self._ddata, @@ -460,7 +460,7 @@ def setup_interactivity( # -------------------- # axes mobile, refs and canvas - daxcan = dict.fromkeys(self._dobj['axes'].keys()) + # daxcan = dict.fromkeys(self._dobj['axes'].keys()) for k0, v0 in self._dobj['axes'].items(): # Update mobile @@ -486,7 +486,7 @@ def setup_interactivity( # --------- # dkeys - dkeys = _class2_interactivity._setup_keys(dkeys=dkeys, dgroup=dgroup) + dkeys = _interactivity._setup_keys(dkeys=dkeys, dgroup=dgroup) # implement dict for ii, (k0, v0) in enumerate(dkeys.items()): @@ -532,7 +532,7 @@ def setup_interactivity( **dinter, ) - _class2_interactivity._set_dbck( + _interactivity._set_dbck( lax=self._dobj['axes'].keys(), daxes=self._dobj['axes'], dcanvas=self._dobj['canvas'], @@ -600,43 +600,44 @@ def connect(self): if self._warn_ifnotInteractive(): return for k0, v0 in self._dobj['canvas'].items(): - keyp = v0['handle'].mpl_connect('key_press_event', self.onkeypress) - keyr = v0['handle'].mpl_connect('key_release_event', self.onkeypress) - butp = v0['handle'].mpl_connect('button_press_event', self.mouseclic) - res = v0['handle'].mpl_connect('resize_event', self.resize) - butr = v0['handle'].mpl_connect('button_release_event', self.mouserelease) - close = v0['handle'].mpl_connect('close_event', self.on_close) - draw = v0['handle'].mpl_connect('draw_event', self.on_draw) + hand = v0['handle'] + keyp = hand.mpl_connect('key_press_event', self.onkeypress) + keyr = hand.mpl_connect('key_release_event', self.onkeypress) + butp = hand.mpl_connect('button_press_event', self.mouseclic) + res = hand.mpl_connect('resize_event', self.resize) + butr = hand.mpl_connect('button_release_event', self.mouserelease) + close = hand.mpl_connect('close_event', self.on_close) + # draw = hand.mpl_connect('draw_event', self.on_draw) # Make sure resizing is doen before 
resize_event # works without re-initializing because not a Qt Action - v0['handle'].manager.toolbar.release = self.mouserelease + hand.manager.toolbar.release = self.mouserelease # v0['handle'].manager.toolbar.release_zoom = self.mouserelease # v0['handle'].manager.toolbar.release_pan = self.mouserelease # make sure home button triggers background update # requires re-initializing because home is a Qt Action # only created by toolbar.addAction() - v0['handle'].manager.toolbar.home = self.new_home + hand.manager.toolbar.home = self.new_home # if _init_toolbar() implemented (matplotlib > ) error = False - if hasattr(v0['handle'].manager.toolbar, '_init_toolbar'): + if hasattr(hand.manager.toolbar, '_init_toolbar'): try: - v0['handle'].manager.toolbar._init_toolbar() + hand.manager.toolbar._init_toolbar() except NotImplementedError: - v0['handle'].manager.toolbar.__init__( - v0['handle'], - v0['handle'].parent(), + hand.manager.toolbar.__init__( + hand, + hand.parent(), ) except Exception as err: error = err - elif hasattr(v0['handle'], 'parent'): + elif hasattr(hand, 'parent'): try: - v0['handle'].manager.toolbar.__init__( - v0['handle'], - v0['handle'].parent(), + hand.manager.toolbar.__init__( + hand, + hand.parent(), ) - except Exception as err: + except Exception: error = True else: error = True @@ -644,9 +645,8 @@ def connect(self): if error is not False: import platform import sys - import inspect - lstr0 = [f"\t- {k1}" for k1 in dir(v0['handle'])] - lstr1 = [f"\t- {k1}" for k1 in dir(v0['handle'].manager.toolbar)] + lstr0 = [f"\t- {k1}" for k1 in dir(hand)] + lstr1 = [f"\t- {k1}" for k1 in dir(hand.manager.toolbar)] msg = ( f"platform: {platform.platform()}\n" f"python: {sys.version}\n" @@ -657,7 +657,7 @@ def connect(self): + "\n".join(lstr1) ) if error is not True: - msg += '\n' + str(err) + msg += '\n' + str(error) warnings.warn(msg) self._dobj['canvas'][k0]['cid'] = { @@ -726,8 +726,8 @@ def _get_current_grouprefdata_from_kax(self, kax=None): # Get current group and ref groupx = self._dobj['axes'][kax]['groupx'] groupy = self._dobj['axes'][kax]['groupy'] - refx = self._dobj['axes'][kax]['refx'] - refy = self._dobj['axes'][kax]['refy'] + # refx = self._dobj['axes'][kax]['refx'] + # refy = self._dobj['axes'][kax]['refy'] # Get kinter kinter = list(self._dobj['interactivity'].keys())[0] @@ -802,7 +802,7 @@ def _getset_current_axref(self, event): types=str, allowed=lkax, ) - ax = self._dobj['axes'][kax]['handle'] + # ax = self._dobj['axes'][kax]['handle'] # Check axes is relevant and toolbar not active lc = [ @@ -831,8 +831,8 @@ def update_interactivity(self): cur_groupy = self._dobj['interactivity'][self.kinter]['cur_groupy'] cur_refx = self._dobj['interactivity'][self.kinter]['cur_refx'] cur_refy = self._dobj['interactivity'][self.kinter]['cur_refy'] - cur_datax = self._dobj['interactivity'][self.kinter]['cur_datax'] - cur_datay = self._dobj['interactivity'][self.kinter]['cur_datay'] + # cur_datax = self._dobj['interactivity'][self.kinter]['cur_datax'] + # cur_datay = self._dobj['interactivity'][self.kinter]['cur_datay'] # Propagate indices through refs if cur_refx is not None: @@ -871,7 +871,7 @@ def update_interactivity(self): ]) ] - self._update_mobiles(lmobiles=lmobiles) # 0.2 s + self._update_mobiles(lmobiles=lmobiles) # 0.2 s if self.debug: self.show_debug() @@ -904,7 +904,7 @@ def _update_mobiles(self, lmobiles=None): # ---- update data of group objects ---- 0.15 s for k0 in lmobiles: - _class2_interactivity._update_mobile( + _interactivity._update_mobile( 
dmobile=self._dobj['mobile'], dref=self._dref, ddata=self._ddata, @@ -913,20 +913,21 @@ def _update_mobiles(self, lmobiles=None): # --- Redraw all objects (due to background restore) --- 25 ms for k0, v0 in self._dobj['mobile'].items(): - v0['handle'].set_visible(v0['visible']) + hand = v0['handle'] + hand.set_visible(v0['visible']) try: - self._dobj['axes'][v0['axes']]['handle'].draw_artist(v0['handle']) + self._dobj['axes'][v0['axes']]['handle'].draw_artist(hand) except Exception as err: print() print(0, k0) # DB print(1, v0['axes']) # DB print(2, self._dobj['axes'][v0['axes']]['handle']) # DB - print(3, v0['handle']) # DB + print(3, hand) # DB print( 4, 'x and y data shapes: ', - [vv.shape for vv in v0['handle'].get_data()] + [vv.shape for vv in hand.get_data()] ) # DB - print(5, 'data: ', v0['handle'].get_data()) + print(5, 'data: ', hand.get_data()) print(err) # DB print() # DB @@ -941,7 +942,7 @@ def _update_mobiles(self, lmobiles=None): # ---------------------- def resize(self, event): - _class2_interactivity._set_dbck( + _interactivity._set_dbck( lax=self._dobj['axes'].keys(), daxes=self._dobj['axes'], dcanvas=self._dobj['canvas'], @@ -955,7 +956,7 @@ def new_home(self, *args): v0['handle'].manager.toolbar.__class__, v0['handle'].manager.toolbar, ).home(*args) - _class2_interactivity._set_dbck( + _interactivity._set_dbck( lax=self._dobj['axes'].keys(), daxes=self._dobj['axes'], dcanvas=self._dobj['canvas'], @@ -1004,7 +1005,9 @@ def mouseclic(self, event): cur_datay = self._dobj['interactivity'][kinter]['cur_datay'] shift = self._dobj['key']['shift']['val'] - ctrl = any([self._dobj['key'][ss]['val'] for ss in ['control', 'ctrl']]) + ctrl = any([ + self._dobj['key'][ss]['val'] for ss in ['control', 'ctrl'] + ]) # Update number of indices (for visibility) gax = [] @@ -1014,7 +1017,7 @@ def mouseclic(self, event): gax += self._dobj['axes'][kax]['groupy'] for gg in set([cur_groupx, cur_groupy]): if gg is not None and gg in gax: - out = _class2_interactivity._update_indices_nb( + out = _interactivity._update_indices_nb( group=gg, dgroup=self._dobj['group'], ctrl=ctrl, @@ -1055,7 +1058,7 @@ def mouseclic(self, event): and cur_refx in self._dobj['axes'][kax]['refx'] ) if c0x: - ix = _class2_interactivity._get_ix_for_refx_only_1or2d( + ix = _interactivity._get_ix_for_refx_only_1or2d( cur_data=cur_datax, cur_ref=cur_refx, eventdata=event.xdata, @@ -1072,7 +1075,7 @@ def mouseclic(self, event): and cur_refy in self._dobj['axes'][kax]['refy'] ) if c0y: - iy = _class2_interactivity._get_ix_for_refx_only_1or2d( + iy = _interactivity._get_ix_for_refx_only_1or2d( cur_data=cur_datay, cur_ref=cur_refy, eventdata=event.ydata, @@ -1120,15 +1123,15 @@ def mouserelease(self, event): if v0['handle'] == event.inaxes.figure.canvas ][0] mode = self._dobj['canvas'][can]['handle'].manager.toolbar.mode.lower() - c0 = 'pan' in mode + c0 = 'pan' in mode c1 = 'zoom' in mode if c0 or c1: kax = self._dobj['interactivity'][self.kinter]['cur_ax_panzoom'] if kax is None: msg = ( - "Make sure you release the mouse button on an axes !" - "\n Otherwise the background plot cannot be properly updated !" + "Make sure you release the mouse button on an axes!" + "\n Otherwise background plot can't be properly updated!" 
) raise Exception(msg) ax = self._dobj['axes'][kax]['handle'] @@ -1142,7 +1145,7 @@ def mouserelease(self, event): ][0] for ax in lax ] - _class2_interactivity._set_dbck( + _interactivity._set_dbck( lax=lax, daxes=self._dobj['axes'], dcanvas=self._dobj['canvas'], @@ -1193,7 +1196,7 @@ def onkeypress(self, event): ln = np.r_[ngen, nmov, ngrp, nind] if np.any(ln > 1) or np.sum(ln) > 2: return - if np.sum(ln) == 2 and (ngrp == 1 or nind ==1 ): + if np.sum(ln) == 2 and (ngrp == 1 or nind == 1): return # only keep relevant keys @@ -1222,7 +1225,7 @@ def onkeypress(self, event): # group group = self._dobj['key'][event.key]['group'] cx = any([ - v0['groupx'] is not None and group in v0['groupx'] + v0['groupx'] is not None and group in v0['groupx'] for v0 in self._dobj['axes'].values() ]) if cx: @@ -1276,7 +1279,10 @@ def onkeypress(self, event): imax = self._dobj['group'][groupx]['nmaxcur'] ii = int(event.key) if ii > imax: - msg = "Set to current max index for group '{groupx}': {imax}" + msg = ( + f"Set to current max index for group '{groupx}':" + f" {imax}" + ) print(msg) ii = min(ii, imax) self._dobj['group'][groupx]['indcur'] = ii @@ -1286,7 +1292,10 @@ def onkeypress(self, event): imax = self._dobj['group'][groupy]['nmaxcur'] ii = int(event.key) if ii > imax: - msg = "Set to current max index for group '{groupy}': {imax}" + msg = ( + f"Set to current max index for group '{groupy}':" + f" {imax}" + ) print(msg) ii = min(ii, imax) self._dobj['group'][groupy]['indcur'] = ii @@ -1339,7 +1348,7 @@ def onkeypress(self, event): return # update nb of visible indices - out = _class2_interactivity._update_indices_nb( + out = _interactivity._update_indices_nb( group=group, dgroup=self._dobj['group'], ctrl=ctrl, @@ -1380,7 +1389,7 @@ def onkeypress(self, event): # ------------------- def on_close(self, event): - self.remove_all(excluded=['canvas']) # to avoid crash + self.remove_all(excluded=['canvas']) # to avoid crash print("\n---- CLOSING interactive figure ----") print(f"\tleft in dax: {self.get_nbytes()[0]/1000} Ko\n") @@ -1393,4 +1402,4 @@ def on_close(self, event): __all__ = [ sorted([k0 for k0 in locals() if k0.startswith('DataStock')])[-1] -] \ No newline at end of file +] diff --git a/datastock/_class2_interactivity.py b/datastock/_class02_interactivity.py similarity index 99% rename from datastock/_class2_interactivity.py rename to datastock/_class02_interactivity.py index 9cc62e0..35c7c26 100644 --- a/datastock/_class2_interactivity.py +++ b/datastock/_class02_interactivity.py @@ -9,7 +9,7 @@ from . import _generic_check from . import _generic_utils -from . import _class1_compute +from . 
import _class01_compute


 _INCREMENTS = [1, 10]


@@ -256,7 +256,7 @@ def _setup_mobile(

     # functions for slicing
     dmobile[k0]['func_slice'] = [
-        _class1_compute._get_slice(
+        _class01_compute._get_slice(
             laxis=dmobile[k0]['axis'][ii],
             ndim=(
                 1 if dmobile[k0]['data'][ii] == 'index'
@@ -442,7 +442,7 @@ def _get_ix_for_refx_only_1or2d(
         raise NotImplementedError()

     # get index of datax corresponding to clicked point
-    return _class1_compute._get_index_from_data(
+    return _class01_compute._get_index_from_data(
         data=cd,
         data_pick=np.r_[eventdata],
         monot=monot,
@@ -588,4 +588,4 @@ def _update_mobile(k0=None, dmobile=None, dref=None, ddata=None):
             # ddata[dmobile[k0]['data'][ii]]['data'][
                 # dmobile[k0]['func_slice'][ii](iref[ii])
             # ]
-        # )
\ No newline at end of file
+        # )
diff --git a/datastock/_class03_Bins.py b/datastock/_class03_Bins.py
new file mode 100644
index 0000000..5aa3efd
--- /dev/null
+++ b/datastock/_class03_Bins.py
@@ -0,0 +1,196 @@
+# -*- coding: utf-8 -*-
+
+
+# Built-in
+import copy
+
+
+# local
+from ._class02 import DataStock2 as Previous
+from . import _class03_checks as _checks
+from . import _class03_bin_vs_bs as _bin_vs_bs
+
+
+__all__ = ['Bins']
+
+
+# #############################################################################
+# #############################################################################
+#
+# #############################################################################
+
+
+class Bins(Previous):
+
+    _which_bins = 'bins'
+    _ddef = copy.deepcopy(Previous._ddef)
+    _dshow = dict(Previous._dshow)
+
+    _dshow.update({
+        _which_bins: [
+            'nd',
+            'shape_edges',
+            'edges',
+            'ref_edges',
+            'is_linear',
+            'is_log',
+        ],
+    })
+
+    # -----------------
+    # bins
+    # ------------------
+
+    def add_bins(
+        self,
+        key=None,
+        edges=None,
+        # custom names
+        key_edges=None,
+        key_cents=None,
+        key_ref_edges=None,
+        key_ref_cents=None,
+        # additional attributes
+        **kwdargs,
+    ):
+        """ Add bins
+
+        Defined from edges, which can be:
+            - np.ndarray or tuple of 2
+            - key to existing monotonous array or tuple of 2
+
+        key names are generated automatically
+        But can also be specified:
+            - for creation
+            - or for referring to existing data
+
+        """
+
+        # --------------
+        # check inputs
+
+        key, dref, ddata, dobj = _checks.check(
+            coll=self,
+            key=key,
+            edges=edges,
+            # custom names
+            key_edges=key_edges,
+            key_cents=key_cents,
+            key_ref_edges=key_ref_edges,
+            key_ref_cents=key_ref_cents,
+            # attributes
+            **kwdargs,
+        )
+
+        # --------------
+        # update dict and crop if relevant
+
+        self.update(dobj=dobj, ddata=ddata, dref=dref)
+
+    def remove_bins(
+        self,
+        key=None,
+        propagate=None,
+    ):
+
+        _checks.remove_bins(
+            coll=self,
+            key=key,
+            propagate=propagate,
+        )
+
+    # -----------------
+    # binning tools
+    # ------------------
+
+    def binning(
+        self,
+        # data to be binned
+        data=None,
+        bin_data0=None,
+        bin_data1=None,
+        axis=None,
+        # bins
+        bins=None,
+        # kind of binning
+        integrate=None,
+        statistic=None,
+        # options
+        safety_ratio=None,
+        dref_vector=None,
+        verb=None,
+        returnas=None,
+        # storing
+        store=None,
+        store_keys=None,
+    ):
+        """ Bin data along ref_key
+
+        Binning is treated here as an integral
+        Hence, if:
+            - the data has units [ph/eV]
+            - the ref_key has units [eV]
+            - the binned data has units [ph]
+
+        return a dict with data and units per key
+
+        data: the data on which to apply binning, can be
+            - a list of np.ndarray to be binned
+                (any dimension as long as they all have the same)
+            - a list of keys to ddata items sharing the same refs
+
+        data_units: str
+            only necessary if data is a list of arrays
+
+        axis: int or array of int indices
+            the axis of data along which to bin
+            data will be flattened along all those axes prior to binning
+            If None, assumes bin_data is not variable and uses all its axes
+
+        bins: the bins (centers), can be
+            - a 1d vector of monotonous bins
+            - an int, used to compute a bins vector from max(data), min(data)
+
+        bin_data0: the data used to compute binning indices, can be:
+            - a str, key to a ddata item
+            - a np.ndarray
+            - a list of any of the above if each data has diff. size along axis
+
+        bin_units: str
+            only used if integrate = True and bin_data is a np.ndarray
+
+        integrate: bool
+            flag indicating whether binning is used for integration
+            Implies that:
+                Only usable for 1d binning (axis has to be a single index)
+                data is multiplied by bin_data0 step prior to binning
+
+        statistic: str
+            the statistic kwd fed to scipy.stats.binned_statistic()
+            automatically set to 'sum' if integrate = True
+
+        store: bool
+            If True, will store the result in ddata
+            Only possible if all (data, bin_data and bins) are provided as keys
+        """
+
+        return _bin_vs_bs.main(
+            coll=self,
+            # data to be binned
+            data=data,
+            bin_data0=bin_data0,
+            bin_data1=bin_data1,
+            axis=axis,
+            # bins
+            bins=bins,
+            # kind of binning
+            integrate=integrate,
+            statistic=statistic,
+            # options
+            safety_ratio=safety_ratio,
+            dref_vector=dref_vector,
+            verb=verb,
+            returnas=returnas,
+            # storing
+            store=store,
+            store_keys=store_keys,
+        )
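
Usage sketch for the Bins API added above (illustration, not part of the patch; the keys 'b0', 'nx', 'x' and 'y' are hypothetical, and add_ref() / add_data() are called with the same signatures used elsewhere in this diff):

    import numpy as np
    import datastock as ds

    coll = ds.Collection()

    # a 1d bin object: 10 bins defined by 11 edges
    coll.add_bins(key='b0', edges=np.linspace(0., 1., 11))

    # a reference and two data arrays sharing it
    x = np.linspace(0., 1., 100)
    coll.add_ref('nx', size=x.size)
    coll.add_data('x', data=x, ref='nx', units='m')
    coll.add_data('y', data=np.sin(x), ref='nx', units='V')

    # bin 'y' using 'x' as the binning coordinate
    # (may warn about units, since 'b0' was created without units)
    dout = coll.binning(
        data='y',
        bin_data0='x',
        bins='b0',
        statistic='mean',
        returnas=True,
    )
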
diff --git a/datastock/_class03_bin_vs_bs.py b/datastock/_class03_bin_vs_bs.py
new file mode 100644
index 0000000..4e98937
--- /dev/null
+++ b/datastock/_class03_bin_vs_bs.py
@@ -0,0 +1,385 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Thu Jan  5 20:14:40 2023
+
+@author: dvezinet
+"""
+
+
+import numpy as np
+
+
+from . import _class03_binning as _binning
+
+
+# ############################################################
+# ############################################################
+# main
+# ############################################################
+
+
+def main(
+    coll=None,
+    data=None,
+    data_units=None,
+    axis=None,
+    # binning
+    bins0=None,
+    bins1=None,
+    bin_data0=None,
+    bin_data1=None,
+    bin_units0=None,
+    # kind of binning
+    integrate=None,
+    statistic=None,
+    # options
+    safety_ratio=None,
+    dref_vector=None,
+    ref_vector_strategy=None,
+    verb=None,
+    returnas=None,
+    # storing
+    store=None,
+    store_keys=None,
+):
+    """ Bin data, interpolating it first if bin_data0 is bsplines-based
+
+    Returns a dict with binned data and units per key
+    """
+
+    # ----------
+    # checks
+
+    # keys
+    isbs, bin_data0 = _check_bs(
+        coll=coll,
+        bin_data0=bin_data0,
+        bin_data1=bin_data1,
+    )
+
+    # ----------
+    # trivial
+
+    nobin = False
+    if isbs:
+
+        # add ref and data
+        kr, kd, ddatan, nobin = _interpolate(
+            coll=coll,
+            data=data,
+            data_units=data_units,
+            # binning
+            bins0=bins0,
+            bin_data0=bin_data0,
+            # options
+            dref_vector=dref_vector,
+            verb=verb,
+            store=store,
+            store_keys=store_keys,
+        )
+
+        # safety check
+        if nobin is False:
+            lk = list(ddatan.keys())
+            data = [ddatan[k0]['data'] for k0 in lk]
+            bin_data0 = [ddatan[k0]['bin_data'] for k0 in lk]
+
+    # --------------------
+    # do the actual binning
+
+    if nobin is False:
+        dout = _binning.main(
+            coll=coll,
+            data=data,
+            data_units=data_units,
+            axis=axis,
+            # binning
+            bins0=bins0,
+            bins1=bins1,
+            bin_data0=bin_data0,
+            bin_data1=bin_data1,
+            bin_units0=bin_units0,
+            # kind of binning
+            integrate=integrate,
+            statistic=statistic,
+            # options
+            safety_ratio=safety_ratio,
+            dref_vector=dref_vector,
+            ref_vector_strategy=ref_vector_strategy,
+            verb=verb,
+            returnas=True,
+            # storing
+            store=store,
+            store_keys=store_keys,
+        )
+
+    # --------------------------------
+    # remove intermediate ref and data
+
+    if isbs is True:
+        for dd in data + bin_data0 + [kd]:
+            if dd in coll.ddata.keys():
+                coll.remove_data(dd)
+        if kr in coll.dref.keys():
+            coll.remove_ref(kr)
+
+        for k0 in data:
+            k1 = [k1 for k1, v1 in ddatan.items() if v1['data'] == k0][0]
+            dout[k1] = dict(dout[k0])
+            del dout[k0]
+    else:
+        dout = nobin
+
+    # ----------
+    # return
+
+    if returnas is True:
+        return dout
+
+
+# ######################################################
+# ######################################################
+# check
+# ######################################################
+
+
+def _check_bs(
+    coll=None,
+    bin_data0=None,
+    bin_data1=None,
+):
+
+    # ----------------
+    # Has bsplines
+    # ----------------
+
+    if hasattr(coll, '_which_bsplines'):
+
+        # ----------------
+        # list of bsplines
+
+        wbs = coll._which_bsplines
+        lok_bs = [
+            k0 for k0, v0 in coll.dobj.get(wbs, {}).items()
+            if len(v0['ref']) == 1
+        ]
+
+        # ----------------
+        # list data with bsplines
+
+        lok_dbs = [
+            k0 for k0, v0 in coll.ddata.items()
+            if v0.get(wbs) is not None
+            and len(v0[wbs]) == 1
+            and v0[wbs][0] in coll.dobj.get(wbs, {}).keys()
+            and len(coll.dobj[wbs][v0[wbs][0]]['ref']) == 1
+        ]
+
+        # ----------------
+        # flag whether is bsplines
+
+        c0 = (
+            isinstance(bin_data0, str)
+            and bin_data1 is None
+            and bin_data0 in lok_dbs + lok_bs
+        )
+
+        # -----------------
+        # adjust bin_data0 from key_bs to key_apex
+
+        if bin_data0 in lok_bs:
+            bin_data0 = coll.dobj[wbs][bin_data0]['apex'][0]
+
+    # ----------------
+    # Does not have bsplines
+    # ----------------
+
+    else:
+
+        c0 = False
+
+    return c0, bin_data0
+
+
+# ######################################################
+# ######################################################
+# interpolate
+# ######################################################
+
+
+def _interpolate(
+    coll=None,
+    data=None,
+    data_units=None,
+    # binning
+    bins0=None,
+    bin_data0=None,
+    # options
+    dref_vector=None,
+    verb=None,
+    store=None,
+    store_keys=None,
+):
+
+    # ---------
+    # sampling
+
+    # mesh knots
+    wm = coll._which_mesh
+    wbs = coll._which_bsplines
+    key_bs = coll.ddata[bin_data0][wbs][0]
+    keym = coll.dobj[wbs][key_bs][wm]
+    kknots = coll.dobj[wm][keym]['knots'][0]
+
+    # resolution
+    vect = coll.ddata[kknots]['data']
+    res0 = np.abs(np.min(np.diff(vect)))
+
+    # ---------
+    # sampling
+
+    ddata = _binning._check_data(
+        coll=coll,
+        data=data,
+        data_units=data_units,
+        store=True,
+    )
+    lkdata = list(ddata.keys())
+
+    # --------------------
+    # bins
+
+    dbins0 = _binning._check_bins(
+        coll=coll,
+        lkdata=lkdata,
+        bins=bins0,
+        dref_vector=dref_vector,
+        store=store,
+    )
+
+    # ----------------------
+    # npts for interpolation
+
+    dv = np.abs(np.diff(vect))
+    dvmean = np.mean(dv) + np.std(dv)
+    db = np.mean(np.diff(dbins0[lkdata[0]]['edges']))
+    npts = (coll.dobj[wbs][key_bs]['deg'] + 3) * max(1, dvmean / db) + 3
+
+    # sample mesh, update dv
+    Dx0 = [dbins0[lkdata[0]]['edges'][0], dbins0[lkdata[0]]['edges'][-1]]
+    xx = coll.get_sample_mesh(
+        keym,
+        res=res0 / npts,
+        mode='abs',
+        Dx0=Dx0,
+    )['x0']['data']
+
+    if xx.size == 0:
+        nobins = _get_nobins(
+            coll=coll,
+            key_bs=key_bs,
+            ddata=ddata,
+            dbins0=dbins0,
+            store=store,
+            store_keys=store_keys,
+        )
+        return None, None, None, nobins
+
+    # -------------------
+    # add ref
+
+    kr = "ntemp"
+    kd = "xxtemp"
+
+    coll.add_ref(kr, size=xx.size)
+    coll.add_data(kd, data=xx, ref=kr, units=coll.ddata[kknots]['units'])
+
+    ddata_new = {}
+    for ii, (k0, v0) in enumerate(ddata.items()):
+
+        # interpolate bin_data
+        kbdn = f"kbdn{ii}_temp"
+        # try:
+        coll.interpolate(
+            keys=bin_data0,
+            ref_key=key_bs,
+            x0=kd,
+            val_out=0.,
+            returnas=False,
+            store=True,
+            inplace=True,
+            store_keys=kbdn,
+        )
+
+        # except Exception as err:
+            # msg = (
+                # err.args[0]
+                # + "\n\n"
+                # f"\t- k0 = {k0}\n"
+                # f"\t- ii = {ii}\n"
+                # f"\t- bin_data0 = {bin_data0}\n"
+                # f"\t- key_bs = {key_bs}\n"
+                # f"\t- kd = {kd}\n"
+                # f"\t- xx.size: {xx.size}\n"
+                # f"\t- kbdn = {kbdn}\n"
+            # )
+            # err.args = (msg,)
+            # raise err
+
+        # interpolate_data
+        kdn = f"kbd{ii}_temp"
+        coll.interpolate(
+            keys=k0,
+            ref_key=key_bs,
+            x0=kd,
+            val_out=0.,
+            returnas=False,
+            store=True,
+            inplace=True,
+            store_keys=kdn,
+        )
+        ddata_new[k0] = {'bin_data': kbdn, 'data': kdn}
+
+    return kr, kd, ddata_new, False
+
+
+def _get_nobins(
+    coll=None,
+    key_bs=None,
+    ddata=None,
+    dbins0=None,
+    store=None,
+    store_keys=None,
+):
+
+    lk = list(ddata.keys())
+    wbs = coll._which_bsplines
+
+    if isinstance(store_keys, str):
+        store_keys = [store_keys]
+
+    dout = {}
+    for ii, k0 in enumerate(lk):
+
+        axis = ddata[k0]['ref'].index(coll.dobj[wbs][key_bs]['ref'][0])
+
+        shape = list(ddata[k0]['data'].shape)
+        nb = dbins0[k0]['edges'].size - 1
+        shape[axis] = nb
+
+        ref = list(ddata[k0]['ref'])
+        ref[axis] = dbins0[k0]['bin_ref'][0]
+
+        dout[store_keys[ii]] = {
+            'data': np.zeros(shape, dtype=float),
+            'ref': tuple(ref),
+            'units': ddata[k0]['units'],
+        }
+
+    if store is True:
+        for k0, v0 in dout.items():
+            coll.add_data(key=k0, **v0)
+
+    return dout
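
The module above only kicks in when the collection exposes bsplines (the hasattr(coll, '_which_bsplines') check): data is first interpolated onto a temporary fine grid, then binned by _class03_binning.main(). A sketch of the interpolation step as called above, with hypothetical keys ('prof0', 'bs0', 'xfine') and assuming a subclass that actually provides bsplines (e.g. tofu):

    # 'prof0' depends on the 1d bsplines 'bs0'; 'xfine' is a fine
    # sampling vector added beforehand (like kd = "xxtemp" above)
    coll.interpolate(
        keys='prof0',          # data to interpolate
        ref_key='bs0',         # bsplines to interpolate against
        x0='xfine',            # abscissa of the interpolation points
        val_out=0.,            # value outside the definition domain
        returnas=False,
        store=True,
        inplace=True,
        store_keys='prof0_fine',
    )
    # 'prof0_fine' is then binned, and the temporary keys removed
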
diff --git a/datastock/_class1_binning.py b/datastock/_class03_binning.py
similarity index 84%
rename from datastock/_class1_binning.py
rename to datastock/_class03_binning.py
index 302781d..e37911f 100644
--- a/datastock/_class1_binning.py
+++ b/datastock/_class03_binning.py
@@ -29,11 +29,11 @@

 # ############################################################
 # ############################################################
-# interpolate spectral
+# main
 # ############################################################


-def binning(
+def main(
     coll=None,
     data=None,
     data_units=None,
@@ -87,7 +87,7 @@
         flag indicating whether binning is used for integration
         Implies that:
             Only usable for 1d binning (axis has to be a single index)
-            data is multiplied by the underlying bin_data0 step prior to binning
+            data is multiplied by the underlying bin_data0 step before binning

     statistic: str
         the statistic kwd feed to scipy.stats.binned_statistic()
@@ -99,8 +99,9 @@

     """

-    # ----------
-    # checks
+    # ---------------------
+    # check inputs
+    # ---------------------

     # keys
     (
@@ -110,15 +111,18 @@
         verb, store, returnas,
     ) = _check(**locals())

-    # --------------
-    # actual binning
+    # -------------------------
+    # binning with fixed edges
+    # -------------------------

     if dvariable['bin0'] is False and dvariable['bin1'] is False:

         dout = {k0: {'units': v0['units']} for k0, v0 in ddata.items()}
         for k0, v0 in ddata.items():

+            # -------------
             # handle dbins1
+
             if dbins1 is None:
                 bins1, vect1, bin_ref1 = None, None, None
             else:
@@ -126,7 +130,9 @@
                 vect1 = dbins1['data']
                 bin_ref1 = dbins1[k0].get('bin_ref')

+            # ------------
             # compute
+
             dout[k0]['data'], dout[k0]['ref'] = _bin_fixed_bin(
                 # data to bin
                 data=v0['data'],
@@ -147,6 +153,10 @@
                 variable_data=dvariable['data'],
             )

+    # -------------------------
+    # binning with variable edges
+    # -------------------------
+
     else:
         msg = (
             "Variable bin vectors not implemented yet!\n"
@@ -156,8 +166,9 @@
         )
         raise NotImplementedError(msg)

-    # --------------
+    # ---------------------
     # storing
+    # ---------------------

     if store is True:

@@ -167,16 +178,18 @@
             store_keys=store_keys,
         )

-    # -------------
+    # ---------------------
     # return
+    # ---------------------

     if returnas is True:
         return dout


-# ####################################
-# check
-# ####################################
+# ################################################################
+# ################################################################
+# Check inputs
+# ################################################################


 def _check(
@@ -185,11 +198,9 @@
     data_units=None,
     axis=None,
     # binning
-    bins0=None,
-    bins1=None,
+    bins=None,
     bin_data0=None,
     bin_data1=None,
-    bin_units0=None,
     # kind of binning
     integrate=None,
     statistic=None,
@@ -216,18 +227,45 @@
         default=True,
     )

+    # store
+    store = _generic_check._check_var(
+        store, 'store',
+        types=bool,
+        default=False,
+    )
+
+    # -----------
+    # bins
+    # ------------
+
+    wb = coll._which_bins
+    lok_bins = list(coll.dobj.get(wb, {}).keys())
+    bins = _generic_check._check_var(
+        bins, 'bins',
+        types=str,
+        allowed=lok_bins,
+    )
+
     # ------------------
-    # data: str vs array
+    # data to be binned
     # -------------------

-    ddata = _check_data(
+    (
+        data, bin_data0, bin_data1, nd_bins, units0, units1,
+    ) = _check_data(
         coll=coll,
         data=data,
-        data_units=data_units,
-        store=store,
+        bin_data0=bin_data0,
+        bin_data1=bin_data1,
+        bins=bins,
+        axis=axis,
     )

-    ndim_data = list(ddata.values())[0]['data'].ndim
+    if data is None:
+        ndim_data = None
+    else:
+        ndim_data = coll.ddata[data]['data'].ndim

     # -----------------
     # check statistic
@@ -247,26 +285,6 @@
     # bins
     # ------------

-    dbins0 = _check_bins(
-        coll=coll,
-        lkdata=list(ddata.keys()),
-        bins=bins0,
-        dref_vector=dref_vector,
-        store=store,
-    )
-    if bins1 is not None:
-        dbins1 = _check_bins(
-            coll=coll,
-            lkdata=list(ddata.keys()),
-            bins=bins1,
-            dref_vector=dref_vector,
-            store=store,
-        )
-
-    # -----------
-    # bins
-    # ------------
-
     # dbins0
     dbins0, variable_bin0, axis = _check_bins_data(
         coll=coll,
@@ -337,7 +355,6 @@
         )
         raise Exception(msg)

-
     # -----------------------
     # additional safety check

@@ -399,20 +416,128 @@
     )


+# ################################################################
+# ################################################################
+# Check data
+# ################################################################
+
+
 def _check_data(
+    coll=None,
+    data=None,
+    bin_data0=None,
+    bin_data1=None,
+    bins=None,
+    axis=None,
+):
+
+    # ---------------
+    # get bin features
+    # ---------------
+
+    wbins = coll._which_bins
+    nd_bins = int(coll.dobj[wbins][bins]['nd'][0])
+    units0 = coll.ddata[coll.dobj[wbins][bins]['edges'][0]]['units']
+    if nd_bins == 2:
+        units1 = coll.ddata[coll.dobj[wbins][bins]['edges'][1]]['units']
+    else:
+        units1 = None
+
+    # ---------------
+    # bin_data0
+    # ---------------
+
+    lok = list(coll.ddata.keys())
+    bin_data0 = _generic_check._check_var(
+        bin_data0, 'bin_data0',
+        types=str,
+        allowed=lok,
+    )
+
+    # check units
+    _check_units(
+        coll=coll,
+        bins=bins,
+        bin_data=bin_data0,
+        bin_data_name='bin_data0',
+        ii=0,
+        units_bins=units0,
+    )
+
+    bin_data_ref = coll.ddata[bin_data0]['ref']
+
+    # ---------------
+    # bin_data1
+    # ---------------
+
+    if nd_bins == 2:
+        lok = [
+            k0 for k0, v0 in coll.ddata.items()
+            if v0['ref'] == bin_data_ref
+        ]
+        bin_data1 = _generic_check._check_var(
+            bin_data1, 'bin_data1',
+            types=str,
+            allowed=lok,
+        )
+
+        # check units
+        _check_units(
+            coll=coll,
+            bins=bins,
+            bin_data=bin_data1,
+            bin_data_name='bin_data1',
+            ii=1,
+            units_bins=units1,
+        )
+
+    else:
+        bin_data1 = None
+
+    # ---------------
+    # data
+    # ---------------
+
+    if data is not None:
+
+        lok = [
+            k0 for k0, v0 in coll.ddata.items()
+            if tuple([rr for rr in v0['ref'] if rr in bin_data_ref]) == bin_data_ref
+        ]
+        data = _generic_check._check_var(
+            data, 'data',
+            types=str,
+            allowed=lok,
+        )

+    return data, bin_data0, bin_data1, nd_bins, units0, units1
+
+
+def _check_units(
+    coll=None,
+    bins=None,
+    bin_data=None,
+    bin_data_name=None,
+    ii=None,
+    units_bins=None,
+):
+    units = coll.ddata[bin_data]['units']
+    c0 = units is not None and units == units_bins
+    if not c0:
+        msg = (
+            "Binning oddity:\n"
+            "\t- detected: unmatching 'units' between bins and bin_data\n"
+            f"\t- Bins: '{bins}' (edges[{ii}])\n"
+            f"\t- Bins units: '{units_bins}'\n"
+            f"\t- {bin_data_name}: '{bin_data}'\n"
+            f"\t- {bin_data_name} units: '{units}'\n"
+        )
+        warnings.warn(msg)
+    return
+
+
+# DEPRECATED
+def _check_data_old(
     coll=None,
     data=None,
     data_units=None,
     store=None,
 ):

-    # -----------
-    # store
-    store = _generic_check._check_var(
-        store, 'store',
-        types=bool,
-        default=False,
-    )
+    # ---------------
+    # trivial
+
+    if data is None:
+        return None, store

     # ---------------------
     # make sure it's a list

@@ -428,7 +553,7 @@
         all([
             isinstance(dd, str)
             and dd in coll.ddata.keys()
-            and coll.ddata[dd]['data'].ndim == coll.ddata[data[0]]['data'].ndim
+            and coll.ddata[dd]['ref'] == coll.ddata[data[0]]['ref']
             for dd in data
         ]),
         all([
             isinstance(dd, np.ndarray)
             for dd in data
         ]),
     ]

     if store is True and not lc[0]:
         msg = "If storing, all data, bin data and bins must be declared!"
         raise Exception(msg)

-
     # if none => err
     if np.sum(lc) != 1:
         msg = (
@@ -482,9 +606,16 @@
             for ii in range(len(data))
         }

-    return ddata
+    return ddata, store
+
+
+# ################################################################
+# ################################################################
+# Check bins
+# ################################################################


+# DEPRECATED
 def _check_bins(
     coll=None,
     lkdata=None,
@@ -575,6 +706,12 @@
     return dbins


+# ################################################################
+# ################################################################
+# Check bins data
+# ################################################################
+
+
 def _check_bins_data(
     coll=None,
     axis=None,
@@ -827,9 +964,10 @@
             lim = safety_ratio * dvmean
             db = np.mean(np.diff(dbins[k0]['edges']))
             if db < lim:
+                ss = f"{safety_ratio} * bin_data step ({lim})"
                 msg = (
                     f"Uncertain binning for bin_data '{v0['key']}':\n"
-                    f"Binning steps ({db}) are < {safety_ratio} * bin_data ({lim}) step"
+                    f"Binning steps ({db}) are < {ss}"
                 )
                 raise Exception(msg)

@@ -1016,7 +1154,8 @@

                     if statistic == 'sum_smooth':
                         val[tuple(sli)] *= (
-                            np.nansum(data[tuple(sli)]) / np.nansum(val[tuple(sli)])
+                            np.nansum(data[tuple(sli)])
+                            / np.nansum(val[tuple(sli)])
                         )

                 else:
@@ -1039,7 +1178,8 @@

                     if statistic == 'sum_smooth':
                         val[tuple(sli_val)] *= (
-                            np.nansum(data[tuple(sli)]) / np.nansum(val[tuple(sli_val)])
+                            np.nansum(data[tuple(sli)])
+                            / np.nansum(val[tuple(sli_val)])
                         )

     # ---------------
@@ -1073,6 +1213,8 @@

     return val, ref


+
+# #######################################################
 # #######################################################
 # Store
 # #######################################################
@@ -1084,7 +1226,6 @@
     store_keys=None,
 ):

-
     # ----------------
     # check store_keys

@@ -1093,6 +1234,7 @@

     ldef = [f"{k0}_binned" for k0 in dout.items()]
     lex = list(coll.ddata.keys())
+
     store_keys = _generic_check._check_var_iter(
         store_keys, 'store_keys',
         types=list,
@@ -1110,4 +1252,4 @@
             data=v0['data'],
             ref=v0['ref'],
             units=v0['units'],
-        )
\ No newline at end of file
+        )
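
As the docstrings note, the fixed-edges path relies on scipy.stats.binned_statistic(). A minimal standalone sketch of the integrate=True convention (statistic forced to 'sum', data weighted by the local bin_data step), independent of datastock:

    import numpy as np
    import scipy.stats as scpstats

    x = np.linspace(0., 1., 1000)      # bin_data0, e.g. in eV
    y = np.cos(x)                      # data, e.g. in ph/eV
    edges = np.linspace(0., 1., 11)    # bin edges

    # integrate = True: weight by the local step, then sum per bin
    dx = np.gradient(x)
    res = scpstats.binned_statistic(x, y * dx, statistic='sum', bins=edges)
    integral_per_bin = res.statistic   # units: ph
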
diff --git a/datastock/_class03_checks.py b/datastock/_class03_checks.py
new file mode 100644
index 0000000..9f9289c
--- /dev/null
+++ b/datastock/_class03_checks.py
@@ -0,0 +1,476 @@
+# -*- coding: utf-8 -*-
+
+
+import numpy as np
+
+
+# Common
+from . import _generic_check
+
+
+# #############################################################################
+# #############################################################################
+# bins generic check
+# #############################################################################
+
+
+def check(
+    coll=None,
+    key=None,
+    edges=None,
+    # custom names
+    key_edges=None,
+    key_cents=None,
+    key_ref_edges=None,
+    key_ref_cents=None,
+    # additional attributes
+    **kwdargs,
+):
+
+    # -------------
+    # key
+    # -------------
+
+    key = _generic_check._obj_key(
+        d0=coll._dobj.get(coll._which_bins, {}),
+        short='b',
+        key=key,
+    )
+
+    # ------------
+    # edges
+    # ------------
+
+    # -----------------------
+    # first conformity check
+
+    lc = [
+        _check_edges_str(edges, coll),
+        _check_edges_array(edges),
+        isinstance(edges, tuple)
+        and len(edges) in (1, 2)
+        and all([
+            _check_edges_str(ee, coll) or _check_edges_array(ee)
+            for ee in edges
+        ])
+    ]
+
+    if np.sum(lc) != 1:
+        msg = (
+            f"For Bins '{key}', arg edges must be:\n"
+            "\t- a str pointing to an existing monotonous vector\n"
+            "\t- an array/list/tuple of unique increasing values\n"
+            "\t- a tuple of 1 or 2 of the above\n"
+            f"Provided:\n\t{edges}"
+        )
+        raise Exception(msg)
+
+    if lc[0] or lc[1]:
+        edges = (edges,)
+
+    # ----------------------------
+    # make tuple of 1d flat arrays
+
+    edges_new = [None for ee in edges]
+    for ii, ee in enumerate(edges):
+        if isinstance(ee, str):
+            edges_new[ii] = ee
+        else:
+            edges_new[ii] = _generic_check._check_flat1darray(
+                ee, f'edges[{ii}]',
+                dtype=float,
+                unique=True,
+                can_be_None=False,
+            )
+
+    # ---------------------
+    # safety check for NaNs
+
+    for ii, ee in enumerate(edges_new):
+        if isinstance(ee, str):
+            ee = coll.ddata[ee]['data']
+
+        isnan = np.any(np.isnan(ee))
+        if isnan:
+            msg = (
+                f"Bins '{key}', provided edges have NaNs!\n"
+                f"\t- edges[{ii}]: {ee}"
+            )
+            raise Exception(msg)
+
+    # --------------
+    # wrap up
+
+    edges = edges_new
+    nd = f"{len(edges)}d"
+
+    # -----------------
+    # kwdargs
+    # -----------------
+
+    for k0, v0 in kwdargs.items():
+        if isinstance(v0, str) or v0 is None:
+            if nd == '1d':
+                kwdargs[k0] = (v0,)
+            else:
+                kwdargs[k0] = (v0, v0)
+
+        c0 = (
+            isinstance(kwdargs[k0], tuple)
+            and len(kwdargs[k0]) == len(edges)
+            and all([isinstance(vv, str) or vv is None for vv in kwdargs[k0]])
+        )
+        if not c0:
+            msg = (
+                f"Bins '{key}', arg kwdargs must be dict of data attributes\n"
+                "Where each attribute is provided as a tuple of "
+                f"len() = len(edges) = ({len(edges)})\n"
+                f"Provided:\n\t{kwdargs}"
+            )
+            raise Exception(msg)
+
+    # -----------------
+    # other keys
+    # -----------------
+
+    key_edges = _check_keys_ref(key_edges, edges, key, 'key_edges')
+    key_cents = _check_keys_ref(key_cents, edges, key, 'key_cents')
+    key_ref_edges = _check_keys_ref(key_ref_edges, edges, key, 'key_ref_edges')
+    key_ref_cents = _check_keys_ref(key_ref_cents, edges, key, 'key_ref_cents')
+
+    # -----------------
+    # edges, cents
+    # -----------------
+
+    # -----------------
+    # key_ref
+
+    dref = {}
+    ddata = {}
+    shape_edges = [None for ee in edges]
+    is_linear = [None for ee in edges]
+    is_log = [None for ee in edges]
+    units = [None for ee in edges]
+    for ii, ee in enumerate(edges):
+        (
+            key_edges[ii], key_cents[ii],
+            key_ref_edges[ii], key_ref_cents[ii],
+            shape_edges[ii],
+            is_linear[ii], is_log[ii],
+            units[ii],
+        ) = _to_dict(
+            coll=coll,
+            key=key,
+            ii=ii,
+            ee=ee,
+            # custom names
+            key_edge=key_edges[ii],
+            key_cent=key_cents[ii],
+            key_ref_edge=key_ref_edges[ii],
+            key_ref_cent=key_ref_cents[ii],
+            # dict
+            dref=dref,
+            ddata=ddata,
+            # attributes
+            **{kk: vv[ii] for kk, vv in kwdargs.items()},
+        )
+
+    # -------------
+    # ref and shape
+
+    shape_cents = tuple([ss - 1 for ss in shape_edges])
+
+    # --------------
+    # dobj
+    # --------------
+
+    # dobj
+    dobj = {
+        coll._which_bins: {
+            key: {
+                'nd': nd,
+                'edges': tuple(key_edges),
+                'cents': tuple(key_cents),
+                'ref_edges': tuple(key_ref_edges),
+                'ref_cents': tuple(key_ref_cents),
+                'shape_edges': tuple(shape_edges),
+                'shape_cents': tuple(shape_cents),
+                'units': tuple(units),
+                'is_linear': tuple(is_linear),
+                'is_log': tuple(is_log),
+            },
+        },
+    }
+
+    return key, dref, ddata, dobj
+
+
+def _check_edges_str(edges, coll):
+    return (
+        isinstance(edges, str)
+        and edges in coll.ddata.keys()
+        and coll.ddata[edges]['monot'] == (True,)
+    )
+
+
+def _check_edges_array(edges):
+    return (
+        isinstance(edges, (list, tuple, np.ndarray))
+        and all([np.isscalar(ee) for ee in edges])
+        and np.array(edges).ndim == 1
+        and np.array(edges).size > 1
+    )
+
+
+def _check_keys_ref(keys, edges, key, keys_name):
+    if keys is None:
+        keys = [None for ee in edges]
+    elif isinstance(keys, str):
+        keys = [keys for ee in edges]
+    elif isinstance(keys, (list, tuple)):
+        c0 = (
+            len(keys) == len(edges)
+            and all([isinstance(ss, str) or ss is None for ss in keys])
+        )
+        if not c0:
+            msg = (
+                f"Bins '{key}', arg '{keys_name}' should be either:\n"
+                "\t- None (automatically set)\n"
+                "\t- str to existing key\n"
+                f"\t- tuple of the above of len() = {len(edges)}\n"
+                f"Provided:\n\t{keys}"
+            )
+            raise Exception(msg)
+    return keys
+
+
+# ##############################################################
+# ###############################################################
+# to_dict
+# ###############################################################
+
+
+def _to_dict(
+    coll=None,
+    key=None,
+    ii=None,
+    ee=None,
+    # custom names
+    key_edge=None,
+    key_cent=None,
+    key_ref_edge=None,
+    key_ref_cent=None,
+    # dict
+    dref=None,
+    ddata=None,
+    # additional attributes
+    **kwdargs,
+):
+    """ check key_edge, key_cents, key_ref_edge, key_ref_cent
+
+    If new, append to dref and ddata
+    """
+
+    # -------------
+    # attributes
+    # -------------
+
+    latt = ['dim', 'quant', 'name', 'units']
+    dim, quant, name, units = [kwdargs.get(ss) for ss in latt]
+
+    # -------------
+    # edges
+    # -------------
+
+    # ref
+    if isinstance(ee, str):
+        key_edge = ee
+        ee = coll.ddata[key_edge]['data']
+        units = coll.ddata[key_edge]['units']
+
+    else:
+
+        # ------------------
+        # key_ref_edge
+
+        defk = f"{key}_ne{ii}"
+        lout = [k0 for k0, v0 in coll.dref.items()]
+        key_ref_edge = _generic_check._check_var(
+            key_ref_edge, defk,
+            types=str,
+            default=defk,
+        )
+        if key_ref_edge in lout:
+            size = coll.dref[key_ref_edge]['size']
+            c0 = size == ee.size
+            if not c0:
+                msg = (
+                    f"Bins '{key}', arg key_ref_edges[{ii}]"
+                    " conflicts with existing ref:\n"
+                    f"\t- coll.dref['{key_ref_edge}']['size'] = {size}\n"
+                    f"\t- edges['{ii}'].size = {ee.size}\n"
+                )
+                raise Exception(msg)
+        else:
+            dref[key_ref_edge] = {'size': ee.size}
+
+        # ---------------
+        # key_edge
+
+        defk = f"{key}_e{ii}"
+        lout = [k0 for k0, v0 in coll.ddata.items()]
+        key_edge = _generic_check._check_var(
+            key_edge, defk,
+            types=str,
+            default=defk,
+            excluded=lout,
+        )
+        ddata[key_edge] = {
+            'data': ee,
+            'ref': key_ref_edge,
+            **kwdargs,
+        }
+
+        units = kwdargs.get('units')
+
+    # shape
+    shape_edge = ee.size
+
+    # ------------------
+    # is_linear, is_log
+    # ------------------
+
+    is_log = (
+        np.all(ee > 0.)
+        and np.allclose(ee[1:] / ee[:-1], ee[1]/ee[0], atol=0, rtol=1e-6)
+    )
+
+    is_linear = np.allclose(np.diff(ee), ee[1] - ee[0], atol=0, rtol=1e-6)
+    assert not (is_log and is_linear), ee
+
+    # ------------
+    # cents
+    # ------------
+
+    # ------------
+    # key_ref_cent
+
+    defk = f"{key}_nc{ii}"
+    lout = [k0 for k0, v0 in coll.dref.items()]
+    key_ref_cent = _generic_check._check_var(
+        key_ref_cent, defk,
+        types=str,
+        default=defk,
+    )
+    if key_ref_cent in lout:
+        size = coll.dref[key_ref_cent]['size']
+        c0 = size == (ee.size - 1)
+        if not c0:
+            msg = (
+                f"Bins '{key}', arg key_ref_cents[{ii}]"
+                " conflicts with existing ref:\n"
+                f"\t- coll.dref['{key_ref_cent}']['size'] = {size}\n"
+                f"\t- edges['{ii}'].size - 1 = {ee.size-1}\n"
+            )
+            raise Exception(msg)
+    else:
+        dref[key_ref_cent] = {'size': ee.size - 1}
+
+    # ------------
+    # key_cent
+
+    defk = f"{key}_c{ii}"
+    lout = [k0 for k0, v0 in coll.ddata.items()]
+    key_cent = _generic_check._check_var(
+        key_cent, defk,
+        types=str,
+        default=defk,
+    )
+    if key_cent in lout:
+        ref = coll.ddata[key_cent]['ref']
+        c0 = ref == (key_ref_cent,)
+        if not c0:
+            msg = (
+                f"Bins '{key}', arg key_cents[{ii}]"
+                " conflicts with existing data:\n"
+                f"\t- coll.ddata['{key_cent}']['ref'] = {ref}\n"
+                f"\t- key_ref_cent = {key_ref_cent}\n"
+            )
+            raise Exception(msg)
+
+    else:
+        if is_log:
+            cents = np.sqrt(ee[:-1] * ee[1:])
+        else:
+            cents = 0.5 * (ee[1:] + ee[:-1])
+
+        ddata[key_cent] = {
+            'data': cents,
+            'ref': (key_ref_cent,),
+            **kwdargs,
+        }
+
+    return (
+        key_edge, key_cent,
+        key_ref_edge, key_ref_cent,
+        shape_edge,
+        is_linear, is_log,
+        units,
+    )
+
+
+# ##############################################################
+# ###############################################################
+# remove bins
+# ###############################################################
+
+
+def remove_bins(coll=None, key=None, propagate=None):
+
+    # ----------
+    # check
+
+    # key
+    wbins = coll._which_bins
+    if wbins not in coll.dobj.keys():
+        return
+
+    if isinstance(key, str):
+        key = [key]
+    key = _generic_check._check_var_iter(
+        key, 'key',
+        types=(list, tuple),
+        types_iter=str,
+        allowed=coll.dobj.get(wbins, {}).keys(),
+    )
+
+    # propagate
+    propagate = _generic_check._check_var(
+        propagate, 'propagate',
+        types=bool,
+        default=True,
+    )
+
+    # ---------
+    # remove
+
+    for k0 in key:
+
+        # specific data
+        kdata = (
+            coll.dobj[wbins][k0]['cents']
+            + coll.dobj[wbins][k0]['edges']
+        )
+        coll.remove_data(kdata, propagate=propagate)
+
+        # specific ref
+        lref = (
+            coll.dobj[wbins][k0]['ref_cents']
+            + coll.dobj[wbins][k0]['ref_edges']
+        )
+        for rr in lref:
+            if rr in coll.dref.keys():
+                coll.remove_ref(rr, propagate=propagate)
+
+        # obj
+        coll.remove_obj(which=wbins, key=k0, propagate=propagate)
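
With this refactor the inheritance chain becomes DataStock0 -> DataStock1 -> DataStock2 -> Bins -> Plots, each module importing its parent as Previous, and __init__.py now exports Plots as Collection. A quick sanity sketch (class names as defined in this diff):

    import datastock as ds

    coll = ds.Collection()
    # expected: ['Plots', 'Bins', 'DataStock2', 'DataStock1', 'DataStock0', ..., 'object']
    print([c.__name__ for c in type(coll).__mro__])
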
import _class04_plot_correlations as _plot_correlations +from . import _class04_plot_BvsA_as_distribution as _plot_BvsA_as_distribution -class DataStock3(DataStock2): +class Plots(Previous): """ Provide default interactive plots """ # ------------------- @@ -414,14 +414,3 @@ def plot_BvsA_as_distribution( # figsize=figsize, dmargin=dmargin, # wintit=wintit, tit=tit, # ) - - -# ############################################################################# -# ############################################################################# -# set __all__ -# ############################################################################# - - -__all__ = [ - sorted([k0 for k0 in locals() if k0.startswith('DataStock')])[-1] -] \ No newline at end of file diff --git a/datastock/_plot_BvsA_as_distribution.py b/datastock/_class04_plot_BvsA_as_distribution.py similarity index 98% rename from datastock/_plot_BvsA_as_distribution.py rename to datastock/_class04_plot_BvsA_as_distribution.py index 89a0723..1a37e02 100644 --- a/datastock/_plot_BvsA_as_distribution.py +++ b/datastock/_class04_plot_BvsA_as_distribution.py @@ -16,9 +16,9 @@ # library-specific from . import _generic_check -from . import _plot_BvsA_as_distribution_check -from . import _plot_text -from . import _class1_compute +from . import _class04_plot_BvsA_as_distribution_check as _plot_BvsA_as_distribution_check +from . import _class04_plot_text as _plot_text +from . import _class01_compute __all__ = ['plot_BvsA_as_distribution'] @@ -187,7 +187,7 @@ def plot_BvsA_as_distribution( if ndim == 1: sli = lambda ind: ind else: - sli = _class1_compute._get_slice(laxis=[1-axis], ndim=2) + sli = _class01_compute._get_slice(laxis=[1-axis], ndim=2) # -------------- # Prepare data diff --git a/datastock/_plot_BvsA_as_distribution_check.py b/datastock/_class04_plot_BvsA_as_distribution_check.py similarity index 100% rename from datastock/_plot_BvsA_as_distribution_check.py rename to datastock/_class04_plot_BvsA_as_distribution_check.py diff --git a/datastock/_plot_as_array.py b/datastock/_class04_plot_as_array.py similarity index 98% rename from datastock/_plot_as_array.py rename to datastock/_class04_plot_as_array.py index 4b6c33f..4e3acce 100644 --- a/datastock/_plot_as_array.py +++ b/datastock/_class04_plot_as_array.py @@ -8,10 +8,10 @@ # library-specific from . import _generic_check -from . import _class1_compute +from . import _class01_compute from . import _generic_utils_plot as _uplot -from . import _plot_as_array_1d -from . import _plot_as_array_234d +from . import _class04_plot_as_array_1d as _plot_as_array_1d +from . 
 
 
 __all__ = ['plot_as_array']
@@ -126,8 +126,7 @@ def plot_as_array(
     # --------------------------------
 
     if sameref:
-        from ._class import DataStock
-        cc = DataStock()
+        cc = coll.__class__()
         lk = ['keyX', 'keyY', 'keyZ', 'keyU']
         lk = [k0 for k0 in lk if dkeys[k0]['ref'] is not None]
         for ii, k0 in enumerate(lk):
@@ -603,7 +602,6 @@ def _check(
         else:
             dvminmax2[k1]['min'] = dvminmax[kk]['min']
 
-
         if dvminmax is None or dvminmax.get(kk, {}).get('max') is None:
             dvminmax2[k1]['max'] = nanmax + margin
         else:
@@ -869,7 +867,7 @@ def get_data_str(dk=None, coll2=None, key=None, ndim=None, dscale=None):
         dk[k1]['axis'] for k1 in lorder
         if k1 != k0 and dk[k1]['key'] is not None
     ]
-    dk[k0]['sli'] = _class1_compute._get_slice(
+    dk[k0]['sli'] = _class01_compute._get_slice(
         laxis=laxis,
         ndim=ndim,
     )
diff --git a/datastock/_plot_as_array_1d.py b/datastock/_class04_plot_as_array_1d.py
similarity index 99%
rename from datastock/_plot_as_array_1d.py
rename to datastock/_class04_plot_as_array_1d.py
index 1b757c5..543a4a6 100644
--- a/datastock/_plot_as_array_1d.py
+++ b/datastock/_class04_plot_as_array_1d.py
@@ -9,7 +9,7 @@
 
 # library-specific
 from . import _generic_check
-from . import _plot_text
+from . import _class04_plot_text as _plot_text
 
 
 # #############################################################
@@ -295,4 +295,4 @@ def _label_axes(
         ax.set_xticks([])
         ax.set_yticks([])
 
-    return dax
\ No newline at end of file
+    return dax
diff --git a/datastock/_plot_as_array_234d.py b/datastock/_class04_plot_as_array_234d.py
similarity index 99%
rename from datastock/_plot_as_array_234d.py
rename to datastock/_class04_plot_as_array_234d.py
index ecfaf15..e505af3 100644
--- a/datastock/_plot_as_array_234d.py
+++ b/datastock/_class04_plot_as_array_234d.py
@@ -9,8 +9,8 @@
 
 # library-specific
 from . import _generic_check
-from . import _class1_compute
-from . import _plot_text
+from . import _class01_compute
+from . import _class04_plot_text as _plot_text
 
 
 # #############################################################
@@ -82,7 +82,7 @@ def sliZ2(*args):
 
     elif ndim >= 3:
         # here slice X => slice in dim Y and vice-versa
-        sliZ2 = _class1_compute._get_slice(
+        sliZ2 = _class01_compute._get_slice(
             laxis=[dkeys[ss]['axis'] for ss in lorder],
             ndim=ndim,
         )
diff --git a/datastock/_plot_as_mobile_lines.py b/datastock/_class04_plot_as_mobile_lines.py
similarity index 98%
rename from datastock/_plot_as_mobile_lines.py
rename to datastock/_class04_plot_as_mobile_lines.py
index c60cf4a..e5859a7 100644
--- a/datastock/_plot_as_mobile_lines.py
+++ b/datastock/_class04_plot_as_mobile_lines.py
@@ -10,9 +10,9 @@
 
 # library-specific
 from . import _generic_check
-from . import _plot_text
-from . import _class1_compute
-from ._plot_as_array import _check_keyXYZ
+from . import _class04_plot_text as _plot_text
+from . import _class01_compute
+from ._class04_plot_as_array import _check_keyXYZ
 from ._generic_utils_plot import _get_str_datadlab
 
 
@@ -445,7 +445,7 @@ def _plot_as_mobile_lines2d(
     # prepare slicing
 
     # here slice X and Y alike => slice in dim Y and vice-versa
-    sli = _class1_compute._get_slice(laxis=[axisch], ndim=2)
+    sli = _class01_compute._get_slice(laxis=[axisch], ndim=2)
 
     # --------------
     # plot - prepare
@@ -688,7 +688,7 @@ def _plot_as_mobile_lines3d(
     # reshape into (nt, nch*(npts+1))
     ntot = nch*(npts + 1)
     order = 'C' if axisch < axispts else 'F'
-    slibck = _class1_compute._get_slice(laxis=[axist], ndim=3)
+    slibck = _class01_compute._get_slice(laxis=[axist], ndim=3)
     bckx = np.array([
         bckx[slibck(ii)].reshape((ntot,), order=order)
         for ii in range(nt)
@@ -723,7 +723,7 @@ def _plot_as_mobile_lines3d(
     # prepare slicing
 
     # here slice X and Y alike => slice in dim Y and vice-versa
-    sli = _class1_compute._get_slice(laxis=[axist, axisch], ndim=3)
+    sli = _class01_compute._get_slice(laxis=[axist, axisch], ndim=3)
 
     # --------------
     # plot - prepare
@@ -962,4 +962,4 @@ def _plot_as_mobile_lines3d(
         bstr_dict=bstr_dict,
     )
 
-    return coll, dax, dgroup
\ No newline at end of file
+    return coll, dax, dgroup
diff --git a/datastock/_plot_as_profile1d.py b/datastock/_class04_plot_as_profile1d.py
similarity index 98%
rename from datastock/_plot_as_profile1d.py
rename to datastock/_class04_plot_as_profile1d.py
index df70e1e..f17b454 100644
--- a/datastock/_plot_as_profile1d.py
+++ b/datastock/_class04_plot_as_profile1d.py
@@ -10,9 +10,9 @@
 
 # library-specific
 from . import _generic_check
-from . import _plot_text
-from . import _class1_compute
-from ._plot_as_array import _check_keyXYZ
+from . import _class04_plot_text as _plot_text
+from . import _class01_compute
+from ._class04_plot_as_array import _check_keyXYZ
 from ._generic_utils_plot import _get_str_datadlab
 
 
@@ -595,8 +595,8 @@ def _plot_as_profile1d(
     # prepare slicing
 
     # here slice X => slice in dim Y and vice-versa
-    slit = _class1_compute._get_slice(laxis=[1-axist], ndim=2)
-    sliX = _class1_compute._get_slice(laxis=[1-axisX], ndim=2)
+    slit = _class01_compute._get_slice(laxis=[1-axist], ndim=2)
+    sliX = _class01_compute._get_slice(laxis=[1-axisX], ndim=2)
     sliXt = _get_sliceXt(laxis=[axist], ndim=dataX.ndim)
 
     # --------------
@@ -971,4 +971,4 @@ def _plot_as_profile1d(
         bstr_dict=bstr_dict,
     )
 
-    return coll, dax, dgroup
\ No newline at end of file
+    return coll, dax, dgroup
diff --git a/datastock/_plot_correlations.py b/datastock/_class04_plot_correlations.py
similarity index 98%
rename from datastock/_plot_correlations.py
rename to datastock/_class04_plot_correlations.py
index d19be76..70eee1d 100644
--- a/datastock/_plot_correlations.py
+++ b/datastock/_class04_plot_correlations.py
@@ -15,9 +15,9 @@
 
 # library-specific
 from . import _generic_check
-from . import _plot_BvsA_as_distribution_check
-from . import _plot_text
-from . import _class2_interactivity
+from . import _class04_plot_BvsA_as_distribution_check as _plot_BvsA_as_distribution_check
+from . import _class04_plot_text as _plot_text
+from . import _class02_interactivity
 
 
 __all__ = ['plot_correlations']
diff --git a/datastock/_plot_text.py b/datastock/_class04_plot_text.py
similarity index 100%
rename from datastock/_plot_text.py
rename to datastock/_class04_plot_text.py
diff --git a/datastock/_direct_calls.py b/datastock/_direct_calls.py
index eadc214..d5be8fc 100644
--- a/datastock/_direct_calls.py
+++ b/datastock/_direct_calls.py
@@ -4,7 +4,7 @@
 
 
 # library-specific
-from ._class import DataStock
+from ._class04_Plots import Plots as Collection
 
 
 __all__ = [
@@ -39,7 +39,7 @@ def plot_as_array(data=None):
     # ---------------------
     # Instanciate datastock
 
-    st = DataStock()
+    st = Collection()
     st.add_data(key='data', data=data)
 
     return st.plot_as_array(inplace=True)
@@ -67,7 +67,7 @@ def plot_BvsA_as_distribution(dataA=None, dataB=None):
     # ---------------------
     # Instanciate datastock
 
-    st = DataStock()
+    st = Collection()
     st.add_data(key='dataA', data=dataA)
     st.add_data(key='dataB', data=dataB)
 
diff --git a/datastock/_saveload.py b/datastock/_saveload.py
index 187991f..25167d9 100644
--- a/datastock/_saveload.py
+++ b/datastock/_saveload.py
@@ -125,8 +125,8 @@ def load(
 
     # cls
     if cls is None:
-        from ._class import DataStock
-        cls = DataStock
+        from ._class04_Plots import Plots as Collection
+        cls = Collection
 
     if not (type(cls) is type and hasattr(cls, 'from_dict')):
         msg = (
@@ -168,6 +168,10 @@ def load(
     # ----------
     # reshape
 
+    # sparse types
+    lsparse = ['csc_', 'bsr_', 'coo_', 'csr_', 'dia_', 'dok_', 'lil_']
+
+    # loop
     dout = {}
     for k0, v0 in dflat.items():
 
@@ -201,7 +205,7 @@ def load(
             dout[k0] = None
         elif typ == 'ndarray':
             dout[k0] = dflat[k0]
-        elif any([ss in typ for ss in ['csc_', 'bsr_', 'coo_', 'csr_', 'dia_', 'dok_', 'lil_']]):
+        elif any([ss in typ for ss in lsparse]):
             assert typ in type(dflat[k0]).__name__
             dout[k0] = dflat[k0]
         elif 'Unit' in typ:
@@ -276,8 +280,10 @@ def get_files(
 
     lc = [
         isinstance(dpfe, (str, tuple)),
-        isinstance(dpfe, list) and all([isinstance(pp, (str, tuple)) for pp in dpfe]),
-        isinstance(dpfe, dict) and all([isinstance(pp, str) for pp in dpfe.keys()])
+        isinstance(dpfe, list)
+        and all([isinstance(pp, (str, tuple)) for pp in dpfe]),
+        isinstance(dpfe, dict)
+        and all([isinstance(pp, str) for pp in dpfe.keys()])
     ]
 
     if not any(lc):
@@ -288,7 +294,7 @@ def get_files(
             "\t\tkeys = valid path str\n"
             "\t\tvalues =\n"
             "\t\t\t- str: valid file names in the associated path\n"
-            "\t\t\t- str: pattern to be found in the files names in that path\n"
+            "\t\t\t- str: pattern to be found in the file names in that path\n"
             "\t\t\t- list of str: list of the above (file names or patterns)\n"
         )
         raise Exception(msg)
@@ -406,7 +412,10 @@ def _get_files_from_path(
 
     lc = [
         any([os.path.isfile(pfe) for pfe in lpfe if isinstance(pfe, str)]),
-        any([os.path.isfile(os.path.join(path, pfe)) for pfe in lpfe if isinstance(pfe, str)]),
+        any([
+            os.path.isfile(os.path.join(path, pfe))
+            for pfe in lpfe if isinstance(pfe, str)
+        ]),
     ]
 
     # ---------------------
@@ -469,4 +478,4 @@ def _get_files_from_path(
     else:
         warnings.warn(msg)
 
-    return out
\ No newline at end of file
+    return out
diff --git a/datastock/tests/test_01_DataStock.py b/datastock/tests/test_01_DataStock.py
index 0ebd01e..943b20e 100644
--- a/datastock/tests/test_01_DataStock.py
+++ b/datastock/tests/test_01_DataStock.py
@@ -6,7 +6,6 @@
 
 # Built-in
 import os
-import warnings
 
 
 # Standard
@@ -14,8 +13,9 @@ import matplotlib.pyplot as plt
 
 
 # datastock-specific
-from .._class import DataStock
+from .._class04_Plots import Plots as Collection
 from .._saveload import load
+from . import test_input as _input
 
 
 _PATH_HERE = os.path.dirname(__file__)
@@ -70,7 +70,10 @@ def _add_data(st=None, nc=None, nx=None, lnt=None):
 
     ne = np.logspace(15, 21, 11)
     Te = np.logspace(1, 5, 21)
-    pec = np.exp(-(ne[:, None] - 1e18)**2/1e5**2 - (Te[None, :] - 5e3)**2/3e3**2)
+    pec = np.exp(
+        -(ne[:, None] - 1e18)**2/1e5**2
+        - (Te[None, :] - 5e3)**2/3e3**2
+    )
 
     lt = [np.linspace(1, 10, nt) for nt in lnt]
     lprof = [(1 + np.cos(t)[:, None]) * x[None, :] for t in lt]
@@ -210,7 +213,7 @@ class Test01_Instanciate():
 
     @classmethod
     def setup_class(cls):
-        cls.st = DataStock()
+        cls.coll = Collection()
        cls.nc = 5
         cls.nx = 80
         cls.lnt = [100, 90, 80, 120, 80]
@@ -220,13 +223,13 @@
     # ------------------------
 
     def test01_add_ref(self):
-        _add_ref(st=self.st, nc=self.nc, nx=self.nx, lnt=self.lnt)
+        _add_ref(st=self.coll, nc=self.nc, nx=self.nx, lnt=self.lnt)
 
     def test02_add_data(self):
-        _add_data(st=self.st, nc=self.nc, nx=self.nx, lnt=self.lnt)
+        _add_data(st=self.coll, nc=self.nc, nx=self.nx, lnt=self.lnt)
 
     def test03_add_obj(self):
-        _add_obj(st=self.st, nc=self.nc)
+        _add_obj(st=self.coll, nc=self.nc)
 
 
 #######################################################
@@ -240,14 +243,14 @@ class Test02_Manipulate():
 
     @classmethod
     def setup_class(cls):
-        cls.st = DataStock()
+        cls.coll = Collection()
         cls.nc = 5
         cls.nx = 80
         cls.lnt = [100, 90, 80, 120, 80]
 
-        _add_ref(st=cls.st, nc=cls.nc, nx=cls.nx, lnt=cls.lnt)
-        _add_data(st=cls.st, nc=cls.nc, nx=cls.nx, lnt=cls.lnt)
-        _add_obj(st=cls.st, nc=cls.nc)
+        _add_ref(st=cls.coll, nc=cls.nc, nx=cls.nx, lnt=cls.lnt)
+        _add_data(st=cls.coll, nc=cls.nc, nx=cls.nx, lnt=cls.lnt)
+        _add_obj(st=cls.coll, nc=cls.nc)
 
     # ------------------------
     # Add / remove
@@ -255,17 +258,17 @@ def setup_class(cls):
 
     def test01_add_param(self):
         # create new 'campaign' parameter for data arrays
-        self.st.add_param('campaign', which='data')
+        self.coll.add_param('campaign', which='data')
 
         # tag each data with its campaign
         for ii in range(self.nc):
-            self.st.set_param(
+            self.coll.set_param(
                 which='data',
                 key=f't{ii}',
                 param='campaign',
                 value=f'c{ii}',
             )
-            self.st.set_param(
+            self.coll.set_param(
                 which='data',
                 key=f'prof{ii}',
                 param='campaign',
@@ -273,38 +276,38 @@
             )
 
     def test02_remove_param(self):
-        self.st.add_param('blabla', which='campaign')
-        self.st.remove_param('blabla', which='campaign')
+        self.coll.add_param('blabla', which='campaign')
+        self.coll.remove_param('blabla', which='campaign')
 
     # ------------------------
     # Selection / sorting
     # ------------------------
 
     def test03_select(self):
-        key = self.st.select(which='data', units='s', returnas=str)
+        key = self.coll.select(which='data', units='s', returnas=str)
         assert key.tolist() == ['t0', 't1', 't2', 't3', 't4']
 
-        out = self.st.select(dim='time', returnas=int)
+        out = self.coll.select(dim='time', returnas=int)
         assert len(out) == 5, out
 
         # test quantitative param selection
-        out = self.st.select(which='campaign', index=[2, 4])
+        out = self.coll.select(which='campaign', index=[2, 4])
         assert len(out) == 3
 
-        out = self.st.select(which='campaign', index=(2, 4))
+        out = self.coll.select(which='campaign', index=(2, 4))
         assert len(out) == 2
 
     def test04_sortby(self):
-        self.st.sortby(which='data', param='units')
+        self.coll.sortby(which='data', param='units')
 
     # ------------------------
     # show
     # ------------------------
 
     def test05_show(self):
-        self.st.show()
-        self.st.show_data()
-        self.st.show_obj()
+        self.coll.show()
+        self.coll.show_data()
+        self.coll.show_obj()
 
     # ------------------------
     # Interpolate
@@ -315,7 +318,7 @@ def test06_get_ref_vector(self):
             hasref, hasvector,
             ref, key_vector,
             values, dind,
-        ) = self.st.get_ref_vector(
+        ) = self.coll.get_ref_vector(
             key='prof0',
             ref='nx',
             values=[1, 2, 2.01, 3],
@@ -326,7 +329,7 @@
 
         assert dind['indr'].shape == (2, 4)
 
     def test07_get_ref_vector_common(self):
-        hasref, ref, key, val, dout = self.st.get_ref_vector_common(
+        hasref, ref, key, val, dout = self.coll.get_ref_vector_common(
             keys=['t0', 'prof0', 'prof1', 't3'],
             dim='time',
         )
@@ -339,111 +342,25 @@ def test08_domain_ref(self):
             'y': [[0, 0.9], (0.1, 0.2)],
             't0': {'domain': [2, 3]},
             't1': {'domain': [[2, 3], (2.5, 3), [4, 6]]},
-            't2': {'ind': self.st.ddata['t2']['data'] > 5},
+            't2': {'ind': self.coll.ddata['t2']['data'] > 5},
         }
 
-        dout = self.st.get_domain_ref(domain=domain)
+        dout = self.coll.get_domain_ref(domain=domain)
         lk = list(domain.keys())
         assert all([isinstance(dout[k0]['ind'], np.ndarray) for k0 in lk])
 
-    def test09_binning(self):
-
-        bins = np.linspace(1, 5, 8)
-        lk = [
-            ('y', 'nx', bins, 0, False, False, 'y_bin0'),
-            ('y', 'nx', bins, 0, True, False, 'y_bin1'),
-            ('y', 'nx', 'x', 0, False, True, 'y_bin2'),
-            ('y', 'nx', 'x', 0, True, True, 'y_bin3'),
-            ('prof0', 'x', 'nt0', 1, False, True, 'p0_bin0'),
-            ('prof0', 'x', 'nt0', 1, True, True, 'p0_bin1'),
-            ('prof0-bis', 'prof0', 'x', [0, 1], False, True, 'p1_bin0'),
-        ]
-
-        for ii, (k0, kr, kb, ax, integ, store, kbin) in enumerate(lk):
-            dout = self.st.binning(
-                data=k0,
-                bin_data0=kr,
-                bins0=kb,
-                axis=ax,
-                integrate=integ,
-                store=store,
-                store_keys=kbin,
-                safety_ratio=0.95,
-                returnas=True,
-            )
-
-            if np.isscalar(ax):
-                ax = [ax]
-
-            if isinstance(kb, str):
-                if kb in self.st.ddata:
-                    nb = self.st.ddata[kb]['data'].size
-                else:
-                    nb = self.st.dref[kb]['size']
-            else:
-                nb = bins.size
-
-            k0 = list(dout.keys())[0]
-            shape = [
-                ss for ii, ss in enumerate(self.st.ddata[k0]['data'].shape)
-                if ii not in ax
-            ]
-
-            shape.insert(ax[0], nb)
-            if dout[k0]['data'].shape != tuple(shape):
-                msg = (
-                    "Mismatching shapes for case {ii}!\n"
-                    f"\t- dout['{k0}']['data'].shape = {dout[k0]['data'].shape}\n"
-                    f"\t- expected: {tuple(shape)}"
-                )
-                raise Exception(msg)
+    def test09_add_bins(self):
+        _input.add_bins(self.coll)
 
-    def test10_interpolate(self):
+    def test10_binning(self):
+        # _input.binning(self.coll)
+        pass
 
-        lk = ['y', 'y', 'prof0', 'prof0', 'prof0', '3d']
-        lref = [None, 'nx', 't0', ['nt0', 'nx'], ['t0', 'x'], ['t0', 'x']]
-        lax = [[0], [0], [0], [0, 1], [0, 1], [1, 2]]
-        lgrid = [False, False, False, False, True, False]
-        llog = [False, False, False, True, False, False]
-
-        x2d = np.array([[1.5, 2.5], [1, 2]])
-        x3d = np.random.random((5, 4, 3))
-        lx0 = [x2d, [1.5, 2.5], [1.5, 2.5], x2d, [1.5, 2.5], x3d]
-        lx1 = [None, None, None, x2d, [1.2, 2.3], x3d]
-        ldom = [None, None, {'nx': [1.5, 2]}, None, None, None]
-
-        zipall = zip(lk, lref, lax, llog, lgrid, lx0, lx1, ldom)
-        for ii, (kk, rr, aa, lg, gg, x0, x1, dom) in enumerate(zipall):
-
-            domain = self.st.get_domain_ref(domain=dom)
-
-            dout = self.st.interpolate(
-                keys=kk,
-                ref_key=rr,
-                x0=x0,
-                x1=x1,
-                grid=gg,
-                deg=2,
-                deriv=None,
-                log_log=lg,
-                return_params=False,
-                domain=dom,
-            )
-
-            assert isinstance(dout, dict)
-            assert isinstance(dout[kk]['data'], np.ndarray)
-            shape = list(self.st.ddata[kk]['data'].shape)
-            x0s = np.array(x0).shape if gg is False else (len(x0), len(x1))
-            if dom is None:
-                shape = tuple(np.r_[shape[:aa[0]], x0s, shape[aa[-1]+1:]])
-            else:
-                shape = tuple(np.r_[x0s, 39]) if ii == 2 else None
-            if dout[kk]['data'].shape != tuple(shape):
-                msg = str(dout[kk]['data'].shape, shape, kk, rr)
-                raise Exception(msg)
+    def test11_interpolate(self):
+        _input.interpolate(self.coll)
 
-    def test11_interpolate_common_refs(self):
+    def test12_interpolate_common_refs(self):
         lk = ['3d', '3d', '3d']
         lref = ['t0', ['nt0', 'nx'], ['nx']]
         lrefc = ['nc', 'nc', 'nt0']
@@ -451,10 +368,10 @@
         llog = [False, True, False]
 
         # add data for common ref interpolation
-        nt0 = self.st.dref['nt0']['size']
-        nt1 = self.st.dref['nt1']['size']
-        nc = self.st.dref['nc']['size']
-        self.st.add_data(
+        nt0 = self.coll.dref['nt0']['size']
+        nt1 = self.coll.dref['nt1']['size']
+        nc = self.coll.dref['nc']['size']
+        self.coll.add_data(
             key='data_com',
             data=1. + np.random.random((nc, nt1, nt0))*2,
             ref=('nc', 'nt1', 'nt0'),
@@ -471,7 +388,7 @@
 
         zipall = zip(lk, lref, lax, llog, lx1, lrefc, ls, lr)
         for ii, (kk, rr, aa, lg, x1, refc, ss, ri) in enumerate(zipall):
 
-            dout, dparams = self.st.interpolate(
+            dout, dparams = self.coll.interpolate(
                 keys=kk,
                 ref_key=rr,
                 x0='data_com',
@@ -490,7 +407,7 @@
 
             assert isinstance(dout[kk]['data'], np.ndarray)
             if not (dout[kk]['data'].shape == ss and dout[kk]['ref'] == ri):
-                lstr = [f'\t- {k0}: {v0}' for k0, v0 in dparams.items()]
+                # lstr = [f'\t- {k0}: {v0}' for k0, v0 in dparams.items()]
                 msg = (
                     "Wrong interpolation shape / ref:\n"
                     f"\t- ii: {ii}\n"
@@ -499,8 +416,8 @@
                     f"\t- x1: {x1}\n"
                     f"\t- ref_com: {refc}\n"
                    f"\t- log_log: {lg}\n"
-                    f"\t- key['ref']: {self.st.ddata[kk]['ref']}\n"
-                    f"\t- x0['ref']: {self.st.ddata['data_com']['ref']}\n"
+                    f"\t- key['ref']: {self.coll.ddata[kk]['ref']}\n"
+                    f"\t- x0['ref']: {self.coll.ddata['data_com']['ref']}\n"
                     "\n"
                     # + "\n".join(lstr)
                     "\n"
@@ -512,7 +429,6 @@
                 )
                 raise Exception(msg)
 
-    # Not tested: float, store=True, inplace
 
     # ------------------------
@@ -520,17 +436,17 @@
 
     def test12_plot_as_array_1d(self):
-        dax = self.st.plot_as_array(key='t0')
+        dax = self.coll.plot_as_array(key='t0')
         plt.close('all')
         del dax
 
     def test13_plot_as_array_2d(self):
-        dax = self.st.plot_as_array(key='prof0')
+        dax = self.coll.plot_as_array(key='prof0')
         plt.close('all')
         del dax
 
     def test14_plot_as_array_2d_log(self):
-        dax = self.st.plot_as_array(
+        dax = self.coll.plot_as_array(
             key='pec', keyX='ne', keyY='Te',
             dscale={'data': 'log'},
         )
@@ -538,27 +454,29 @@
         plt.close('all')
         del dax
 
     def test15_plot_as_array_3d(self):
-        dax = self.st.plot_as_array(key='3d', dvminmax={'keyX': {'min': 0}})
+        dax = self.coll.plot_as_array(key='3d', dvminmax={'keyX': {'min': 0}})
         plt.close('all')
         del dax
 
     def test16_plot_as_array_3d_ZNonMonot(self):
-        dax = self.st.plot_as_array(key='3d', keyZ='y')
+        dax = self.coll.plot_as_array(key='3d', keyZ='y')
         plt.close('all')
         del dax
 
     def test17_plot_as_array_4d(self):
-        dax = self.st.plot_as_array(key='4d', dscale={'keyU': 'linear'})
+        dax = self.coll.plot_as_array(key='4d', dscale={'keyU': 'linear'})
         plt.close('all')
         del dax
 
     # def test18_plot_BvsA_as_distribution(self):
-    #     dax = self.st.plot_BvsA_as_distribution(keyA='prof0', keyB='prof0-bis')
+    #     dax = self.coll.plot_BvsA_as_distribution(
+    #         keyA='prof0', keyB='prof0-bis',
+    #     )
    #     plt.close('all')
    #     del dax
 
     def test19_plot_as_profile1d(self):
-        dax = self.st.plot_as_profile1d(
+        dax = self.coll.plot_as_profile1d(
             key='prof0',
             key_time='t0',
             keyX='prof0-bis',
@@ -570,7 +488,7 @@
 
     # def test20_plot_as_mobile_lines(self):
 
        # # 3d
-        # dax = self.st.plot_as_mobile_lines(
+        # dax = self.coll.plot_as_mobile_lines(
            # keyX='3d',
            # keyY='3d-bis',
            # key_time='t0',
        # )
@@ -578,7 +496,7 @@
 
        # # 2d
-        # dax = self.st.plot_as_mobile_lines(
+        # dax = self.coll.plot_as_mobile_lines(
            # keyX='prof2',
            # keyY='prof2-bis',
            # key_chan='nx',
@@ -592,21 +510,21 @@
 
     # ------------------------
 
     def test21_copy_equal(self):
-        st2 = self.st.copy()
-        assert st2 is not self.st
+        st2 = self.coll.copy()
+        assert st2 is not self.coll
 
-        msg = st2.__eq__(self.st, returnas=str)
+        msg = st2.__eq__(self.coll, returnas=str)
         if msg is not True:
             raise Exception(msg)
 
     def test22_get_nbytes(self):
-        nb, dnb = self.st.get_nbytes()
+        nb, dnb = self.coll.get_nbytes()
 
     def test23_saveload(self, verb=False):
-        pfe = self.st.save(path=_PATH_OUTPUT, verb=verb, return_pfe=True)
+        pfe = self.coll.save(path=_PATH_OUTPUT, verb=verb, return_pfe=True)
         st2 = load(pfe, verb=verb)
 
         # Just to check the loaded version works fine
-        msg = st2.__eq__(self.st, returnas=str)
+        msg = st2.__eq__(self.coll, returnas=str)
         if msg is not True:
             raise Exception(msg)
-        os.remove(pfe)
\ No newline at end of file
+        os.remove(pfe)
diff --git a/datastock/tests/test_input.py b/datastock/tests/test_input.py
new file mode 100644
index 0000000..035aab1
--- /dev/null
+++ b/datastock/tests/test_input.py
@@ -0,0 +1,180 @@
+
+
+import numpy as np
+
+
+# ###############################################################
+# ###############################################################
+#                  Add bins
+# ###############################################################
+
+
+def add_bins(coll):
+
+    # ---------------
+    # check if needed
+    # ---------------
+
+    wbins = coll._which_bins
+    if coll.dobj.get(wbins) is not None:
+        return
+
+    # -------------------------
+    # define bins from scratch
+    # -------------------------
+
+    # linear uniform 1d
+    coll.add_bins('b1d_lin', edges=np.linspace(0, 1, 10), units='m')
+
+    # log uniform 1d
+    coll.add_bins('b1d_log', edges=np.logspace(0, 1, 10), units='eV')
+
+    # non-uniform 1d
+    coll.add_bins('b1d_rand', edges=np.r_[1, 2, 5, 10, 12, 20], units='s')
+
+    # linear uniform 2d
+    coll.add_bins(
+        'b2d_lin',
+        edges=(np.linspace(0, 1, 10), np.linspace(0, 3, 20)),
+        units='m',
+    )
+
+    # log uniform mix 2d
+    coll.add_bins(
+        'b2d_mix',
+        edges=(np.logspace(0, 1, 10), np.pi*np.r_[0, 0.5, 1, 1.2, 1.5, 2]),
+        units=('eV', 'rad'),
+    )
+
+    # -------------------------
+    # define bins pre-existing
+    # -------------------------
+
+    return
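+
+
+# a quick illustrative sanity check (a sketch, not part of the suite):
+# for the log-uniform edges above, the cents created by add_bins are
+# geometric means, sqrt(e[i] * e[i+1]), so their logs sit exactly midway
+# between the log-edges:
+#
+#     edges = np.logspace(0, 1, 10)
+#     cents = np.sqrt(edges[:-1] * edges[1:])
+#     assert np.allclose(
+#         np.log(cents), 0.5*(np.log(edges[:-1]) + np.log(edges[1:])),
+#     )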
+
+
+# ###############################################################
+# ###############################################################
+#                  Binning
+# ###############################################################
+
+
+def binning(coll):
+
+    # ---------------
+    # check if needed
+    # ---------------
+
+    wbins = coll._which_bins
+    if coll.dobj.get(wbins) is None:
+        add_bins(coll)
+
+    # -------------------
+    # Binning
+    # -------------------
+
+    bins = np.linspace(1, 5, 8)
+    lk = [
+        ('y', 'nx', bins, 0, False, False, 'y_bin0'),
+        ('y', 'nx', bins, 0, True, False, 'y_bin1'),
+        ('y', 'nx', 'x', 0, False, True, 'y_bin2'),
+        ('y', 'nx', 'x', 0, True, True, 'y_bin3'),
+        ('prof0', 'x', 'nt0', 1, False, True, 'p0_bin0'),
+        ('prof0', 'x', 'nt0', 1, True, True, 'p0_bin1'),
+        ('prof0-bis', 'prof0', 'x', [0, 1], False, True, 'p1_bin0'),
+    ]
+
+    for ii, (k0, kr, kb, ax, integ, store, kbin) in enumerate(lk):
+        dout = coll.binning(
+            data=k0,
+            bin_data0=kr,
+            bins0=kb,
+            axis=ax,
+            integrate=integ,
+            store=store,
+            store_keys=kbin,
+            safety_ratio=0.95,
+            returnas=True,
+        )
+
+        if np.isscalar(ax):
+            ax = [ax]
+
+        if isinstance(kb, str):
+            if kb in coll.ddata:
+                nb = coll.ddata[kb]['data'].size
+            else:
+                nb = coll.dref[kb]['size']
+        else:
+            nb = bins.size
+
+        k0 = list(dout.keys())[0]
+        shape = [
+            ss for ii, ss in enumerate(coll.ddata[k0]['data'].shape)
+            if ii not in ax
+        ]
+
+        shape.insert(ax[0], nb)
+        if dout[k0]['data'].shape != tuple(shape):
+            shstr = dout[k0]['data'].shape
+            msg = (
+                f"Mismatching shapes for case {ii}!\n"
+                f"\t- dout['{k0}']['data'].shape = {shstr}\n"
+                f"\t- expected: {tuple(shape)}"
+            )
+            raise Exception(msg)
+
+    return
+
+
+# ###############################################################
+# ###############################################################
+#                  Interpolate
+# ###############################################################
+
+
+def interpolate(coll):
+
+    lk = ['y', 'y', 'prof0', 'prof0', 'prof0', '3d']
+    lref = [None, 'nx', 't0', ['nt0', 'nx'], ['t0', 'x'], ['t0', 'x']]
+    lax = [[0], [0], [0], [0, 1], [0, 1], [1, 2]]
+    lgrid = [False, False, False, False, True, False]
+    llog = [False, False, False, True, False, False]
+
+    x2d = np.array([[1.5, 2.5], [1, 2]])
+    x3d = np.random.random((5, 4, 3))
+    lx0 = [x2d, [1.5, 2.5], [1.5, 2.5], x2d, [1.5, 2.5], x3d]
+    lx1 = [None, None, None, x2d, [1.2, 2.3], x3d]
+    ldom = [None, None, {'nx': [1.5, 2]}, None, None, None]
+
+    zipall = zip(lk, lref, lax, llog, lgrid, lx0, lx1, ldom)
+    for ii, (kk, rr, aa, lg, gg, x0, x1, dom) in enumerate(zipall):
+
+        _ = coll.get_domain_ref(domain=dom)
+
+        dout = coll.interpolate(
+            keys=kk,
+            ref_key=rr,
+            x0=x0,
+            x1=x1,
+            grid=gg,
+            deg=2,
+            deriv=None,
+            log_log=lg,
+            return_params=False,
+            domain=dom,
+        )
+
+        assert isinstance(dout, dict)
+        assert isinstance(dout[kk]['data'], np.ndarray)
+        shape = list(coll.ddata[kk]['data'].shape)
+        x0s = np.array(x0).shape if gg is False else (len(x0), len(x1))
+        if dom is None:
+            shape = tuple(np.r_[shape[:aa[0]], x0s, shape[aa[-1]+1:]])
+        else:
+            shape = tuple(np.r_[x0s, 39]) if ii == 2 else None
+        if dout[kk]['data'].shape != tuple(shape):
+            msg = (
+                f"Wrong interpolation shape for case {ii} ({kk}, {rr}):\n"
+                f"\t- dout['{kk}']['data'].shape = {dout[kk]['data'].shape}\n"
+                f"\t- expected: {shape}"
+            )
+            raise Exception(msg)
+
+    return
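To close, a self-contained sketch exercising interpolate() the way the helper above does (hedged: it mirrors the second case of the lists above, and the add_ref/add_data/interpolate signatures are taken from the calls in this diff; exact defaults may differ):

    import numpy as np
    import datastock as ds

    # build a collection with a 1d data array on ref 'nx'
    coll = ds.Collection()
    coll.add_ref(key='nx', size=80)
    coll.add_data(key='x', data=np.linspace(1, 10, 80), ref='nx')
    coll.add_data(key='y', data=np.sin(np.linspace(1, 10, 80)), ref='nx')

    # interpolate 'y' at two points along its reference vector
    dout = coll.interpolate(
        keys='y', ref_key='nx', x0=[1.5, 2.5],
        deg=2, return_params=False,
    )

    # per the shape logic above, the interpolated axis takes x0's shape
    assert dout['y']['data'].shape == (2,)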