diff --git a/pbjam/__init__.py b/pbjam/__init__.py
index e2c2e40..0e7c357 100644
--- a/pbjam/__init__.py
+++ b/pbjam/__init__.py
@@ -4,7 +4,20 @@ import os
 
 PACKAGEDIR = os.path.abspath(os.path.dirname(__file__))
 
+# Setup global pbjam logger
+import logging
+logger = logging.getLogger(__name__)
+logger.setLevel('DEBUG') # <--- minimum level for global pbjam package logger
+
+# Setup console handler
+from .jar import _stream_handler
+console_handler = _stream_handler(level='INFO')
+logger.addHandler(console_handler)
+logger.debug(f'Initializing {__name__}')
+
 from .version import __version__
+logger.debug(f'version == {__version__}')
+
 from .priors import kde
 from .session import session
 from .asy_peakbag import asymp_spec_model, asymptotic_fit
@@ -12,4 +25,6 @@ from .ellone import ellone
 from .star import star
 from .mcmc import mcmc
-from .mcmc import nested
\ No newline at end of file
+from .mcmc import nested
+
+logger.debug(f'Initialized {__name__}')
diff --git a/pbjam/asy_peakbag.py b/pbjam/asy_peakbag.py
index 37b05be..d1f715c 100755
--- a/pbjam/asy_peakbag.py
+++ b/pbjam/asy_peakbag.py
@@ -12,9 +12,13 @@ import pandas as pd
 import scipy.stats as scist
 from .plotting import plotting
-from .jar import normal
+from .jar import normal, debug
 from collections import OrderedDict
-import warnings
+import warnings, logging
+
+logger = logging.getLogger(__name__)
+debugger = debug(logger)
+
 
 class asymp_spec_model():
     """Class for spectrum model using asymptotic relation.
@@ -35,7 +39,7 @@ class asymp_spec_model():
         Number of radial order to fit.
 
     """
-
+    # @debugger
     def __init__(self, f, norders):
         self.f = np.array([f]).flatten()
         self.norders = int(norders)
@@ -326,7 +330,7 @@ class asymptotic_fit(plotting, asymp_spec_model):
         science results!
 
     """
-
+    # @debugger
     def __init__(self, st, norders=None):
 
         self.pg = st.pg
@@ -354,7 +358,11 @@ def __init__(self, st, norders=None):
         self.path = st.path
 
         st.asy_fit = self
-
+
+    def __repr__(self):
+        return f'<pbjam.asy_peakbag.asymptotic_fit>'
+
+    @debugger
     def __call__(self, method, developer_mode):
         """ Setup, run and parse the asymptotic relation fit.
@@ -380,7 +388,7 @@ def __call__(self, method, developer_mode):
         self.developer_mode = developer_mode
 
         if method not in ['emcee', 'cpnest']:
-            warnings.warn(f'Method {method} not found: Using method emcee')
+            logger.warning(f'Method {method} not found: Using default method emcee')
             method = 'emcee'
 
         if method == 'emcee':
@@ -400,7 +408,6 @@ def __call__(self, method, developer_mode):
 
         return {'modeID': self.modeID, 'summary': self.summary}
 
-
     def prior(self, p):
         """ Calculates the log prior
 
@@ -514,7 +521,7 @@ def _get_summary_stats(self, fit):
 
         return summary
 
-
+    @debugger
     def get_modeIDs(self, fit, norders):
         """ Set mode ID in a dataframe
 
diff --git a/pbjam/data/pbjam_references.bib b/pbjam/data/pbjam_references.bib
index 29a0b29..23c3776 100644
--- a/pbjam/data/pbjam_references.bib
+++ b/pbjam/data/pbjam_references.bib
@@ -1,5 +1,5 @@
-@article{nested,
+@article{cpnest,
 title={johnveitch/cpnest: Minor optimisation},
 DOI={10.5281/zenodo.835874},
 publisher={Zenodo},
diff --git a/pbjam/ellone.py b/pbjam/ellone.py
index c482978..1d29879 100644
--- a/pbjam/ellone.py
+++ b/pbjam/ellone.py
@@ -31,10 +31,15 @@ from sklearn.preprocessing import MinMaxScaler
 from sklearn.utils import shuffle as skshuffle
 import hdbscan as Hdbscan
-import warnings
+import warnings, logging
 from .plotting import plotting
 import astropy.units as units
 import lightkurve as lk
+from .jar import debug
+
+logger = logging.getLogger(__name__)
+debugger = debug(logger)
+
 
 class ellone(plotting):
     """ Basic l=1 detection
@@ -71,7 +76,7 @@ class ellone(plotting):
     instead, in which case the l=2,0 modes may be picked up instead of the l=1.
 
     """
-
+
     def __init__(self, pbinst=None, f=None, s=None):
 
         if pbinst:
@@ -101,7 +106,8 @@ def __init__(self, pbinst=None, f=None, s=None):
         self.hdblabels = None
         self.hdbX = None
         self.hdb_clusterN = None
-
+
+    @debugger
     def residual(self,):
         """ Compute the residual after dividing out l=2,0
 
@@ -129,7 +135,7 @@ def residual(self,):
             idx = (flad[0] <= self.f) & (self.f <= flad[-1])
             res[idx] /= mod[i,:]
         return res
-
+
     def binning(self, nbin):
         """ Simply mean-binning
 
@@ -177,7 +183,7 @@ def H0test(self, fbin, sbin, nbin, dnu, reject=0.1):
         idx = k < reject
 
         return idx, k
-
+    @debugger
     def H0_inconsistent(self, dnu, Nmax, rejection_level):
         """ Find bins inconsistent with noise
 
@@ -220,6 +226,7 @@ def H0_inconsistent(self, dnu, Nmax, rejection_level):
 
         return nu, N, pH0s
 
+    @debugger
     def clustering_preprocess(self, nu, N, limits = (0, 100000)):
         """ Preprocess the samples before clustering
 
@@ -268,6 +275,7 @@ def span(self, x):
 
         return max(x)-min(x)
 
+    @debugger
     def clustering(self, nu, N, Nmax, outlier_limit=0.5, cluster_prob=0.9):
         """ Perform HDBscan clustering
 
@@ -326,8 +334,7 @@ def clustering(self, nu, N, Nmax, outlier_limit=0.5, cluster_prob=0.9):
 
         return nus[1:], nstds[1:]
 
-
-
+    @debugger
     def get_ell1(self, dnu):
         """ Estimate frequency of l=1 modes (p-modes)
 
@@ -376,10 +383,11 @@ def get_ell1(self, dnu):
                 nul1s_std[i] = self.cluster_stds[nuidx][maxidx]
 
             if (nul0s[i] - nul1s[i])/d01 > 0.2:
-                warnings.warn('Cluster nu_l1 exceeds UP estimate by more than 20%')
+                logger.warning('Cluster nu_l1 exceeds UP estimate by more than 20%')
 
         return nul1s, nul1s_std
 
+    @debugger
     def __call__(self, dnu, Nmax = 30, rejection_level = 0.1):
         """ Perform all the steps to estimate l=1 frequencies
 
diff --git a/pbjam/jar.py b/pbjam/jar.py
index 7848389..a8cb9da 100644
--- a/pbjam/jar.py
+++ b/pbjam/jar.py
@@ -7,10 +7,334 @@ from . import PACKAGEDIR
 import os
 import numpy as np
+import pandas as pd
 from scipy.special import erf
+import functools, logging, inspect, sys, warnings
+from .printer import pretty_printer
+
+HANDLER_FMT = "%(asctime)-23s :: %(levelname)-8s :: %(name)-17s :: %(message)s"
+INDENT = 60 # Set to length of logger info before message or just indent by 2?
+logger = logging.getLogger(__name__)
+
+_pp_kwargs = {'width': 120}
+if sys.version_info[0] == 3 and sys.version_info[1] >= 8:
+    # 'sort_dicts' kwarg new to Python 3.8
+    _pp_kwargs['sort_dicts'] = False
+
+pprinter = pretty_printer(**_pp_kwargs)
+
+
+class _function_logger:
+    """ Handles the logging upon entering and exiting functions. """
+
+    def __init__(self, func, logger):
+        self.func = func
+        self.signature = inspect.signature(self.func)
+        self.logger = logger
+
+    def _log_bound_args(self, args, kwargs):
+        """ Logs bound arguments - ``args`` and ``kwargs`` passed to func. """
+        bargs = self.signature.bind(*args, **kwargs)
+        bargs_dict = dict(bargs.arguments)
+        self.logger.debug(f"Bound arguments:\n{pprinter.pformat(bargs_dict)}")
+
+    def _entering_function(self, args, kwargs):
+        """ Log before function execution. """
+        self.logger.debug(f"Entering {self.func.__qualname__}")
+        self.logger.debug(f"Signature:\n{self.func.__name__ + str(self.signature)}")
+        self._log_bound_args(args, kwargs)
+        # TODO: stuff to check before entering function
+
+    def _exiting_function(self, result):
+        """ Log after function execution. """
+        # TODO: stuff to check before exiting function
+        if result is not None:
+            self.logger.debug(f"Returns:\n{pprinter.pformat(result)}")
+        self.logger.debug(f"Exiting {self.func.__qualname__}")
+
+
+def debug(logger):
+    """
+    Function logging decorator. Logs function metadata upon entering and
+    exiting.
+
+    Parameters
+    ----------
+    logger: logging.Logger
+        Specify the logger in which to submit entering and exiting logs, highly
+        recommended to be the module-level logger (see Examples).
+
+    Examples
+    --------
+    Logging a function called ``my_func`` defined in a module with name ``__name__``,
+
+    .. code-block:: python
+
+        import logging
+        from pbjam.jar import debug
+
+        logger = logging.getLogger(__name__)
+        debugger = debug(logger)
+
+        @debugger
+        def my_func(a, b):
+            logger.debug('Function in progress.')
+            return a + b
+
+        if __name__ == "__main__":
+            logging.basicConfig()
+            logger.setLevel('DEBUG')
+
+            result = my_func(1, 2)
+            logger.debug(f'result = {result}')
+
+    Outputs,
+
+    .. code-block:: text
+
+        DEBUG:__main__:Entering my_func
+        DEBUG:__main__:Function in progress.
+        DEBUG:__main__:Exiting my_func
+        DEBUG:__main__:result = 3
+
+    For use within classes,
+
+    .. code-block:: python
+
+        import logging
+        from pbjam.jar import debug
+
+        logger = logging.getLogger(__name__)
+        debugger = debug(logger)
+
+
+        class myClass:
+
+            @debugger
+            def __init__(self):
+                logger.debug('Initializing class.')
+                self.a = 1
+                self.b = 2
+
+            @debugger
+            def my_mthd(self):
+                logger.debug('Method in progress.')
+                return self.a + self.b
+
+        if __name__ == "__main__":
+            logging.basicConfig()
+            logger.setLevel('DEBUG')
+
+            obj = myClass()
+            result = obj.my_mthd()
+            logger.debug(f'result = {result}')
+
+    Outputs,
+
+    .. code-block:: text
+
+        DEBUG:__main__:Entering myClass.__init__
+        DEBUG:__main__:Initializing class.
+        DEBUG:__main__:Exiting myClass.__init__
+        DEBUG:__main__:Entering myClass.my_mthd
+        DEBUG:__main__:Method in progress.
+        DEBUG:__main__:Exiting myClass.my_mthd
+        DEBUG:__main__:result = 3
+
+    """
+    def _log(func):
+        @functools.wraps(func)
+        def wrap(*args, **kwargs):
+            flog = _function_logger(func, logger)
+            flog._entering_function(args, kwargs)
+            result = func(*args, **kwargs)
+            flog._exiting_function(result)
+            return result
+        return wrap
+
+    return _log
+
+
+class _formatter(logging.Formatter):
+
+    def format(self, *args, **kwargs):
+        s = super(_formatter, self).format(*args, **kwargs)
+        lines = s.split('\n')
+        return ('\n' + ' '*INDENT).join(lines)
+
+
+class _handler(logging.Handler):
+
+    def __init__(self, level='NOTSET', **kwargs):
+        super().__init__(**kwargs)
+        fmt = _formatter(HANDLER_FMT)
+        self.setFormatter(fmt)
+        self.setLevel(level)
+
+
+class _stream_handler(_handler, logging.StreamHandler):
+
+    def __init__(self, level='INFO', **kwargs):
+        super(_stream_handler, self).__init__(level=level, **kwargs)
+
+
+class _file_handler(_handler, logging.FileHandler):
+
+    def __init__(self, filename, level='DEBUG', **kwargs):
+        super(_file_handler, self).__init__(filename=filename, level=level, **kwargs)
+
+
+class log_file:
+    """
+    Context manager for file logging. It logs everything under the ``loggername``
+    logger, by default this is the ``'pbjam'`` logger (i.e. logs everything from
+    the pbjam package).
+
+    Parameters
+    ----------
+    filename : str
+        Filename to save the log
+    level : str, optional
+        Logging level. Default is 'DEBUG'.
+    loggername : str, optional
+        Name of logger which will send logs to ``filename``. Default is ``'pbjam'``.
+
+    Attributes
+    ----------
+    handler : pbjam.jar._file_handler
+        File handler object.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        from pbjam.jar import log_file
+
+        with log_file('example.log') as flog:
+            # Do some pbjam stuff here and it will be logged to 'example.log'
+            ...
+
+        # Do some stuff here and it won't be logged to 'example.log'
+
+        with flog:
+            # Do some stuff here and it will be logged to 'example.log'
+            ...
+
+    """
+    def __init__(self, filename, level='DEBUG', loggername='pbjam'):
+        self._filename = filename
+        self._level = level
+        self._logger = logging.getLogger(loggername)
+        self.handler = None
+        self._isopen = False
+
+    def open(self):
+        """ If log file is not open, creates a file handler at the log level """
+        if not self._isopen:
+            self.handler = _file_handler(self._filename, level=self._level)
+            self._logger.addHandler(self.handler)
+            self._isopen = True
+
+    def close(self):
+        """ If log file is open, safely closes the file handler """
+        if self._isopen:
+            self._logger.removeHandler(self.handler)
+            self.handler.close()
+            self.handler = None
+            self._isopen = False
+
+    def get_level(self):
+        return self._level
+
+    def set_level(self, level):
+        """
+        Set the level of the file handler.
+
+        Parameters
+        ----------
+        level : str
+            Choose from 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL' or
+            'NOTSET'.
+        """
+        self._level = level
+        if self._isopen:
+            self.handler.setLevel(self._level)
+
+    def __enter__(self):
+        self.open()
+        return self
+
+    def __exit__(self, type, value, traceback):
+        self.close()
+
+
+class file_logger:
+    """
+    Creates a ``log_file`` at ``filename`` to which logs under ``loggername`` at
+    a given ``level`` are recorded when the file logger is listening. This
+    class is intended to be sub-classed (see Examples).
+
+    To listen to a method in a sub-class of ``file_logger`` (i.e. record all logs
+    which occur during the method execution) decorate the class method with
+    ``@file_logger.listen``.
+
+    Parameters
+    ----------
+    filename : str
+        Filename to save the log
+    level : str, optional
+        Logging level. Default is 'DEBUG'.
+    loggername : str, optional
+        Name of logger which will send logs to ``filename``. Default is ``'pbjam'``.
+
+    Attributes
+    ----------
+    log_file : pbjam.jar.log_file
+
+    Examples
+    --------
+    .. code-block:: python
+
+        # pbjam/example.py
+        from .jar import file_logger
+
+        class example_class(file_logger):
+            def __init__(self):
+                super(example_class, self).__init__('example.log', level='INFO')
+
+                with self.log_file:
+                    # Records content in context to log_file
+                    logger.info('Initializing class.')
+                    ...
+
+            @file_logger.listen # records content of example_method to log_file
+            def example_method(self):
+                logger.info('Performing function tasks.')
+                ...
+
+    """
+
+    def __init__(self, *args, **kwargs):
+        self.log_file = log_file(*args, **kwargs)
+
+    @staticmethod
+    def listen(func):
+        """
+        Decorator for recording logs to ``log_file`` during function operation,
+        closing the log file upon completion.
+        """
+        @functools.wraps(func)
+        def wrap(self, *args, **kwargs):
+            self.log_file.open()
+            result = func(self, *args, **kwargs)
+            self.log_file.close()
+            return result
+        return wrap
+
+
 class references():
-    """ A class for managing references used when running PBjam.
+    """
+    A class for managing references used when running PBjam.
 
     This is inherited by session and star.
 
@@ -62,12 +386,12 @@ def _findBlockEnd(self, string, idx):
             a += 1
 
         if (i >= len(string[idx:])-1) and (a != 0):
-            print('Warning: Reached end of bibtex file with no closing curly bracket. Your .bib file may be formatted incorrectly. The reference list may be garbled.')
+            logger.warning('Reached end of bibtex file with no closing curly bracket. Your .bib file may be formatted incorrectly. The reference list may be garbled.')
 
         if a ==0:
             break
 
     if string[idx+i] == '{':
-        print('Warning: Ended on an opening bracket. Your .bib file may be formatted incorrectly.')
+        logger.warning('Ended on an opening bracket. Your .bib file may be formatted incorrectly.')
 
     return idx+i
 
@@ -171,7 +495,8 @@ def get_priorpath():
 
 def get_percentiles(X, nsigma = 2, **kwargs):
-    """ Get percentiles of an distribution
+    """
+    Get percentiles of a distribution
 
     Compute the percentiles corresponding to sigma=1,2,3.. including the
    median (50th), of an array.
 
@@ -201,7 +526,8 @@ def get_percentiles(X, nsigma = 2, **kwargs):
 
 def to_log10(x, xerr):
-    """ Transform to value to log10
+    """
+    Transform value to log10
 
     Takes a value and related uncertainty and converts them to logscale.
    Approximate.
 
@@ -225,7 +551,8 @@ def to_log10(x, xerr):
     return [x, xerr]
 
 def normal(x, mu, sigma):
-    """ Evaluate logarithm of normal distribution (not normalized!!)
+    """
+    Evaluate logarithm of normal distribution (not normalized!!)
 
     Evaluates the logarithm of a normal distribution at x.
 
@@ -246,4 +573,4 @@ def normal(x, mu, sigma):
     if (sigma < 0):
         return 0.0
 
-    return -0.5 * (x - mu)**2 / sigma**2
\ No newline at end of file
+    return -0.5 * (x - mu)**2 / sigma**2
diff --git a/pbjam/mcmc.py b/pbjam/mcmc.py
index 6a7fcd6..5f6223d 100644
--- a/pbjam/mcmc.py
+++ b/pbjam/mcmc.py
@@ -10,7 +10,13 @@ import scipy.stats as st
 import cpnest.model
 import pandas as pd
-import os
+import os, logging
+
+from .jar import debug
+
+logger = logging.getLogger(__name__)
+debugger = debug(logger)
+
 
 class mcmc():
     """ Class for MCMC sampling using `emcee'
 
@@ -50,7 +56,7 @@ class mcmc():
         Acceptance fraction at each step.
""" - + # @debugger def __init__(self, start, likelihood, prior, nwalkers=50): self.start = start @@ -69,6 +75,9 @@ def __init__(self, start, likelihood, prior, nwalkers=50): self.flatlnlike = None self.acceptance = None + def __repr__(self): + return f'' + def logpost(self, p): """ Evaluate the likelihood and prior @@ -116,7 +125,7 @@ def stationarity(self, nfactor=20): converged = np.all(tau * nfactor < self.sampler.iteration) return converged - + @debugger def __call__(self, max_iter=20000, spread=1e-4, start_samples=[]): """ Initialize and run the EMCEE afine invariant sampler @@ -165,15 +174,16 @@ def __call__(self, max_iter=20000, spread=1e-4, start_samples=[]): pos, prob, state = self.sampler.run_mcmc(initial_state=pos, nsteps=nsteps) while not self.stationarity(): pos, prob, state = self.sampler.run_mcmc(initial_state=pos, nsteps=nsteps) - print(f'Steps taken: {self.sampler.iteration}') + logger.info(f'Steps taken: {self.sampler.iteration}') if self.sampler.iteration == max_iter: break if self.sampler.iteration < max_iter: - print(f'Chains reached stationary state after {self.sampler.iteration} iterations.') + logger.info(f'Chains reached stationary state after {self.sampler.iteration} iterations.') elif self.sampler.iteration == max_iter: - print(f'Sampler stopped at {max_iter} (maximum). Chains did not necessarily reach a stationary state.') + logger.warning(f'Sampler stopped at {max_iter} (maximum). Chains did not necessarily reach a stationary state.') else: - print('Unhandled exception') + # TODO: handle this exception + logger.critical('Unhandled exception') # Fold in low AR chains and run a little bit to update emcee self.fold(pos, spread=spread) @@ -199,7 +209,7 @@ def __call__(self, max_iter=20000, spread=1e-4, start_samples=[]): return self.flatchain - + # @debugger def fold(self, pos, accept_lim = 0.2, spread=0.1): """ Fold low acceptance walkers into main distribution @@ -265,7 +275,7 @@ class nested(cpnest.model.Model): Function that will return the log prior when called as prior(params) """ - + # @debugger def __init__(self, names, bounds, likelihood, prior, path): self.names=names self.bounds=bounds @@ -276,6 +286,8 @@ def __init__(self, names, bounds, likelihood, prior, path): if not os.path.isdir(self.path): os.mkdir(self.path) + def __repr__(self): + return f'' def log_likelihood(self, param): """ Wrapper for log likelihood """ @@ -286,6 +298,7 @@ def log_prior(self,p): if not self.in_bounds(p): return -np.inf return self.prior(p.values) + @debugger def __call__(self, nlive=100, nthreads=1, maxmcmc=100, poolsize=100): """ Runs the nested sampling @@ -316,4 +329,4 @@ def __call__(self, nlive=100, nthreads=1, maxmcmc=100, poolsize=100): self.samples = pd.DataFrame(self.nest.get_posterior_samples())[self.names] self.flatchain = self.samples.values self.acceptance = None - return self.samples \ No newline at end of file + return self.samples diff --git a/pbjam/peakbag.py b/pbjam/peakbag.py index 860a02f..fe246fc 100644 --- a/pbjam/peakbag.py +++ b/pbjam/peakbag.py @@ -7,8 +7,14 @@ import numpy as np import pymc3 as pm -import warnings +import arviz as az +import warnings, logging, inspect from .plotting import plotting +from .jar import debug + +logger = logging.getLogger(__name__) +debugger = debug(logger) + class peakbag(plotting): """ Class for the final peakbagging. @@ -62,7 +68,7 @@ class peakbag(plotting): See asy_peakbag asymptotic_fit for more details. 
""" - + # @debugger def __init__(self, starinst, init=True, path=None, verbose=False): self.pg = starinst.pg @@ -79,7 +85,10 @@ def __init__(self, starinst, init=True, path=None, verbose=False): starinst.peakbag = self + def __repr__(self): + return '' + @debugger def make_start(self): """ Set the starting model for peakbag @@ -111,6 +120,7 @@ def make_start(self): self.n = np.linspace(0.0, 1.0, len(self.start['l0']))[:, None] + @debugger def remove_outsiders(self, l0, l2): """ Drop outliers @@ -130,6 +140,7 @@ def remove_outsiders(self, l0, l2): sel = np.where(np.logical_and(l0 < self.f.max(), l0 > self.f.min())) return l0[sel], l2[sel] + @debugger def trim_ladder(self, lw_fac=10, extra=0.01, verbose=False): """ Turns mode frequencies into list of pairs @@ -156,18 +167,23 @@ def trim_ladder(self, lw_fac=10, extra=0.01, verbose=False): w = d02_lw + (extra * 10**self.asy_fit.summary.loc['dnu', 'mean']) bw = self.f[1] - self.f[0] w /= bw - if verbose: - print(f'w = {int(w)}') - print(f'bw = {bw}') + # if verbose: + # print(f'w = {int(w)}') + # print(f'bw = {bw}') + logger.debug(f'w = {int(w)}') + logger.debug(f'bw = {bw}') + ladder_trim_f = np.zeros([len(self.start['l0']), int(w)]) ladder_trim_s = np.zeros([len(self.start['l0']), int(w)]) for idx, freq in enumerate(self.start['l0']): loc_mid_02 = np.argmin(np.abs(self.f - (freq - d02/2.0))) if loc_mid_02 == 0: - warnings.warn('Did not find optimal pair location') - if verbose: - print(f'loc_mid_02 = {loc_mid_02}') - print(f'w/2 = {int(w/2)}') + logger.warning('Did not find optimal pair location') + # if verbose: + # print(f'loc_mid_02 = {loc_mid_02}') + # print(f'w/2 = {int(w/2)}') + logger.debug(f'loc_mid_02 = {loc_mid_02}') + logger.debug(f'w/2 = {int(w/2)}') ladder_trim_f[idx, :] = \ self.f[loc_mid_02 - int(w/2): loc_mid_02 - int(w/2) + int(w)] ladder_trim_s[idx, :] = \ @@ -236,6 +252,7 @@ def model(self, l0, l2, width0, width2, height0, height2, back): mod += self.lor(l2, width2, height2) return mod.T + @debugger def init_model(self, model_type): """ Initialize the pymc3 model for peakbag @@ -267,7 +284,7 @@ def init_model(self, model_type): if model_type != 'model_gp': if model_type != 'simple': # defaults to simple if bad input - warnings.warn('Model not defined - using simple model') + logger.warning('Model not defined - using simple model') width0 = pm.Lognormal('width0', mu=np.log(self.start['width0']), sigma=width_fac, shape=N) width2 = pm.Lognormal('width2', mu=np.log(self.start['width2']), @@ -277,7 +294,7 @@ def init_model(self, model_type): self.target_accept = 0.9 elif model_type == 'model_gp': - warnings.warn('This model is developmental - use carefully') + logger.warning('This model is developmental - use carefully') # Place a GP over the l=0 mode widths ... m0 = pm.Normal('gradient0', 0, 10) c0 = pm.Normal('intercept0', 0, 10) @@ -339,7 +356,7 @@ def _addPPRatio(self): self.summary.at[idx, 'log_ppr'] = log_ppr[idx] - + @debugger def __call__(self, model_type='simple', tune=1500, nthreads=1, maxiter=4, advi=False): """ Perform all the steps in peakbag. @@ -370,9 +387,8 @@ def __call__(self, model_type='simple', tune=1500, nthreads=1, maxiter=4, # REMOVE THIS WHEN pymc3 v3.8 is a bit older. 
         try:
             rhatfunc = pm.diagnostics.gelman_rubin
-            warnings.warn('pymc3.diagnostics.gelman_rubin is depcrecated; upgrade pymc3 to v3.8 or newer.', DeprecationWarning)
         except:
-            rhatfunc = pm.stats.rhat
+            rhatfunc = az.rhat
 
         if advi:
@@ -388,24 +404,36 @@ def __call__(self, model_type='simple', tune=1500, nthreads=1, maxiter=4,
         else:
             Rhat_max = 10
             niter = 1
+
+        sample_kwargs = dict(tune=tune * niter, cores=nthreads,
+                             start=self.start,
+                             init=self.init_sampler,
+                             target_accept=self.target_accept,
+                             progressbar=False)
+
+        # To suppress future warning - check back in future
+        if 'return_inferencedata' in inspect.getfullargspec(pm.sample).kwonlyargs:
+            sample_kwargs['return_inferencedata'] = False
+
         while Rhat_max > 1.05:
             if niter > maxiter:
-                warnings.warn('Did not converge!')
+                logger.warning('Did not converge!')
                 break
+
+            sample_kwargs['tune'] = tune * niter
+
             with self.pm_model:
-                self.traces = pm.sample(tune=tune * niter, cores=nthreads,
-                                        start=self.start,
-                                        init=self.init_sampler,
-                                        target_accept=self.target_accept,
-                                        progressbar=False)
-            Rhat_max = np.max([v.max() for k, v in rhatfunc(self.traces).items()])
-            niter += 1
+                self.traces = pm.sample(**sample_kwargs)
+
+            Rhat_max = np.max([v.max() for k, v in rhatfunc(self.traces).items()])
+            niter += 1
 
-        # REMOVE THIS WHEN pymc3 v3.8 is a bit older
-        try:
-            self.summary = pm.summary(self.traces)
-        except:
-            self.summary = pm.stats.summary(self.traces)
+        with self.pm_model:
+            # REMOVE THIS WHEN pymc3 v3.8 is a bit older
+            try:
+                self.summary = pm.summary(self.traces)
+            except:
+                self.summary = az.summary(self.traces)
 
         self.par_names = self.summary.index
diff --git a/pbjam/plotting.py b/pbjam/plotting.py
index c98f3ac..b1a026b 100644
--- a/pbjam/plotting.py
+++ b/pbjam/plotting.py
@@ -13,7 +13,14 @@ import astropy.units as u
 import pandas as pd
 
-class plotting():
+import logging
+from .jar import debug
+
+logger = logging.getLogger(__name__) # For module-level logging
+debugger = debug(logger)
+
+
+class plotting:
     """ Class inherited by PBjam modules to plot results
 
     This is used to standardize the plots produced at various steps of the
@@ -24,9 +30,9 @@ class plotting():
     called from.
 
     """
-
-    def __init__(self):
-        pass
+
+    def __init__(self, *args, **kwargs):
+        super(plotting, self).__init__(*args, **kwargs)
 
     def _save_my_fig(self, fig, figtype, path, ID):
         """ Save the figure object
 
@@ -52,7 +58,8 @@ def _save_my_fig(self, fig, figtype, path, ID):
         if path and ID:
             outpath = os.path.join(*[path, type(self).__name__+f'_{figtype}_{str(ID)}.png'])
             fig.savefig(outpath)
-
+
+    @debugger
     def plot_echelle(self, pg=None, path=None, ID=None, savefig=False):
         """ Make echelle plot
 
@@ -153,6 +160,7 @@ def plot_echelle(self, pg=None, path=None, ID=None, savefig=False):
 
         return fig
 
+    @debugger
     def plot_corner(self, path=None, ID=None, savefig=False):
         """ Make corner plot of result.
 
@@ -176,7 +184,7 @@ def plot_corner(self, path=None, ID=None, savefig=False):
         """
 
         if not hasattr(self, 'samples'):
-            warnings.warn(f"'{self.__class__.__name__}' has no attribute 'samples'. Can't plot a corner plot.")
+            logger.error(f"'{self.__class__.__name__}' has no attribute 'samples'. Can't plot a corner plot.")
             return None
 
         fig = corner.corner(self.samples, labels=self.par_names,
@@ -188,6 +196,7 @@ def plot_corner(self, path=None, ID=None, savefig=False):
 
         return fig
 
+    @debugger
     def plot_spectrum(self, pg=None, path=None, ID=None, savefig=False):
         """ Plot the power spectrum
 
@@ -421,8 +430,7 @@ def _make_prior_corner(self, df, numax_rng = 100):
 
         return crnr, crnr.get_axes()
 
-
-
+    @debugger
     def plot_prior(self, path=None, ID=None, savefig=False):
         """ Corner of result in relation to prior sample.
 
@@ -474,6 +482,7 @@ def plot_prior(self, path=None, ID=None, savefig=False):
 
         return crnr
 
+    @debugger
     def plot_start(self):
         """ Plot starting point for peakbag
 
diff --git a/pbjam/printer.py b/pbjam/printer.py
new file mode 100644
index 0000000..e2cb4d4
--- /dev/null
+++ b/pbjam/printer.py
@@ -0,0 +1,37 @@
+import numpy as np
+import pandas as pd
+from pprint import PrettyPrinter
+
+
+class pretty_printer(PrettyPrinter):
+    _dispatch = PrettyPrinter._dispatch.copy() # keep default handlers so built-in types still pretty-print
+
+    def _format_ndarray(self, object, stream, indent, allowance, context, level):
+        write = stream.write
+        max_width = self._width - indent - allowance
+        with np.printoptions(linewidth=max_width):
+            string = repr(object)
+
+        lines = string.split('\n')
+        string = ('\n' + indent * ' ').join(lines)
+        write(string)
+
+    def _pprint_ndarray(self, object, stream, indent, allowance, context, level):
+        self._format_ndarray(object, stream, indent, allowance, context, level)
+
+    _dispatch[np.ndarray.__repr__] = _pprint_ndarray
+
+    def _format_dataframe(self, object, stream, indent, allowance, context, level):
+        write = stream.write
+        max_width = self._width - indent - allowance
+        with pd.option_context('display.width', max_width, 'display.max_columns', None):
+            string = repr(object)
+
+        lines = string.split('\n')
+        string = f'\n{indent*" "}'.join(lines)
+        write(string)
+
+    def _pprint_dataframe(self, object, stream, indent, allowance, context, level):
+        self._format_dataframe(object, stream, indent, allowance, context, level)
+
+    _dispatch[pd.DataFrame.__repr__] = _pprint_dataframe
diff --git a/pbjam/priors.py b/pbjam/priors.py
index 55a9a9a..100e511 100644
--- a/pbjam/priors.py
+++ b/pbjam/priors.py
@@ -12,7 +12,12 @@ import warnings
 from .plotting import plotting
 import statsmodels.api as sm
-from .jar import get_priorpath, to_log10, normal
+from .jar import get_priorpath, to_log10, normal, debug
+import logging
+
+logger = logging.getLogger(__name__)
+debugger = debug(logger)
+
 
 class kde(plotting):
     """ A class to produce prior for asy_peakbag and initial starting location.
@@ -52,7 +57,7 @@ class kde(plotting):
         to compute the KDE. Default is to use pbjam/data/prior_data.csv
 
     """
-
+    # @debugger
     def __init__(self, starinst=None, prior_file=None):
 
         if starinst:
@@ -74,6 +79,10 @@ def __init__(self, starinst=None, prior_file=None):
 
         self.verbose = False
 
+    def __repr__(self):
+        return f'<pbjam.priors.kde>'
+
+    @debugger
     def select_prior_data(self, numax=None, KDEsize = 100):
         """ Selects useful prior data based on proximity to estimated numax.
 
@@ -149,9 +158,9 @@ def _prior_size_check(self, pdata, numax, KDEsize):
             idx = np.abs(pdata.numax.values - numax[0]) < nsigma * numax[1]
 
             if not flag_warn:
-                warnings.warn(f'Only {len(pdata[idx])} star(s) near provided numax. ' +
-                              f'Trying to expand the range to include ~{KDEsize} stars.')
-                flag_warn = True
+                logger.warning(f'Only {len(pdata[idx])} star(s) near provided numax. ' +
+                               f'Trying to expand the range to include ~{KDEsize} stars.')
+                flag_warn = True # So this message only appears once
 
             if nsigma >= KDEsize:
                 break
@@ -161,16 +170,18 @@ def _prior_size_check(self, pdata, numax, KDEsize):
         ntgts = len(idx[idx==1])
 
         if ntgts == 0:
-            raise ValueError('No prior targets found within range of target. This might mean no prior samples exist for stars like this, consider increasing the uncertainty on your numax input.')
+            raise ValueError('No prior targets found within range of target. This might mean no prior samples exist' + \
+                             ' for stars like this, consider increasing the uncertainty on your numax input.')
 
         elif ntgts < KDEsize:
-            warnings.warn(f'Sample for estimating KDE is less than the requested {KDEsize}.')
+            msg = f'Sample size for estimating prior KDE is {ntgts}, less than the desired {KDEsize} - ' + \
+                  'the prior may not comprise similar stars. If your uncertainty on numax is < 1 per cent, it may be too small.'
+            logger.warning(msg)
             KDEsize = ntgts
 
         return pdata.sample(KDEsize, weights=idx, replace=False)
 
-
+    @debugger
     def make_kde(self, bw_fac=1.0):
         """ Takes the prior data and constructs a KDE function
 
@@ -203,19 +214,23 @@ def make_kde(self, bw_fac=1.0):
 
         self.select_prior_data(self._log_obs['numax'])
 
-        if self.verbose:
-            print(f'Selected data set length {len(self.prior_data)}')
+        # if self.verbose:
+        #     print(f'Selected data set length {len(self.prior_data)}')
+        logger.debug(f'Selected prior dataset length: {len(self.prior_data)}')
 
         if bw_fac != 1:
+            logger.info('Selecting stars for KDE with user-specified bandwidth.')
             from statsmodels.nonparametric.bandwidths import select_bandwidth
             bw = select_bandwidth(self.prior_data[self.par_names].values,
                                   bw = 'scott', kernel=None)
             bw *= bw_fac
 
         else:
-            if self.verbose:
-                print('Selecting sensible stars for kde')
-                print(f'Full data set length {len(self.prior_data)}')
+            # if self.verbose:
+            #     print('Selecting sensible stars for kde')
+            #     print(f'Full data set length {len(self.prior_data)}')
+
+            logger.info('Automatically selecting stars for KDE')
             bw = 'cv_ml'
 
         self.kde = sm.nonparametric.KDEMultivariate(
@@ -223,7 +238,6 @@ def make_kde(self, bw_fac=1.0):
             var_type='c'*len(self.par_names),
             bw=bw)
 
-
     def prior(self, p):
         """ Calculates the log prior for the initial guess fit.
 
@@ -282,6 +296,7 @@ def likelihood(self, p):
 
         return lnlike
 
+    @debugger
     def kde_predict(self, n):
         """ Predict the l=0 mode frequencies from the KDE samples.
 
@@ -314,7 +329,7 @@ def kde_predict(self, n):
 
         return freq.mean(axis=1), freq.std(axis=1)
 
-
+    @debugger
     def kde_sampler(self, nwalkers=50):
         """ Samples the posterior distribution with the KDE prior
 
@@ -340,8 +355,9 @@ def kde_sampler(self, nwalkers=50):
 
         """
 
-        if self.verbose:
-            print('Running KDE sampler')
+        # if self.verbose:
+        #     print('Running KDE sampler')
+        logger.info('Running KDE sampler')
 
         x0 = [self._log_obs['dnu'][0],  # log10 dnu
               self._log_obs['numax'][0],  # log10 numax
diff --git a/pbjam/session.py b/pbjam/session.py
index 0eed73c..1605336 100755
--- a/pbjam/session.py
+++ b/pbjam/session.py
@@ -49,12 +49,15 @@ import numpy as np
 import astropy.units as units
 import pandas as pd
-import os, pickle, warnings
+import os, pickle, warnings, logging
 from .star import star, _format_name
 from datetime import datetime
-from .jar import references
+from .jar import references, debug, file_logger
 
+logger = logging.getLogger(__name__)
+debugger = debug(logger)
 
+@debugger
 def _organize_sess_dataframe(vardf):
     """ Takes input dataframe and tidies it up.
@@ -90,7 +93,7 @@ def _organize_sess_dataframe(vardf):
     if 'spectrum' not in vardf.keys():
         _format_col(vardf, None, 'spectrum')
 
-
+@debugger
 def _organize_sess_input(**vardct):
     """ Takes input and organizes them in a dataframe.
 
@@ -131,6 +134,7 @@ def _organize_sess_input(**vardct):
         vardf[key+'_err'] = np.array(vardct[key]).reshape((-1, 2))[:, 1].flatten()
     return vardf
 
+@debugger
 def _sort_lc(lc):
     """ Sort a lightcurve in Lightkurve object.
 
@@ -154,6 +158,7 @@ def _sort_lc(lc):
 
     return lc
 
+@debugger
 def _query_lightkurve(ID, download_dir, use_cached, lkwargs):
     """ Get time series using LightKurve
 
@@ -197,7 +202,7 @@ def _query_lightkurve(ID, download_dir, use_cached, lkwargs):
 
     return lc
 
-
+@debugger
 def _arr_to_lk(x, y, name, typ):
     """ LightKurve object from input.
 
@@ -231,7 +236,7 @@ def _arr_to_lk(x, y, name, typ):
     else:
         raise KeyError("Don't modify anything but spectrum and timeseries cols")
 
-
+@debugger
 def _format_col(vardf, col, key):
     """ Add timeseries or spectrum column to dataframe based on input
 
@@ -298,8 +303,10 @@ def _format_col(vardf, col, key):
                 np.array([_arr_to_lk(x, y, vardf.loc[i, 'ID'], key)]))
         vardf[key] = temp
     else:
-        print('Unhandled exception')
+        # TODO: handle this exception
+        logger.critical('Unhandled exception.')
 
+@debugger
 def _lc_to_lk(ID, tsIn, specIn, download_dir, use_cached, lkwargs):
     """ Convert time series column in dataframe to lk.LightCurve object
 
@@ -358,7 +365,7 @@ def _lc_to_lk(ID, tsIn, specIn, download_dir, use_cached, lkwargs):
 
     return tsOut
 
-
+@debugger
 def _lk_to_pg(ID, tsIn, specIn):
     """ Convert spectrum column in dataframe to Lightkurve.periodgram objects
 
@@ -404,8 +411,7 @@ def _lk_to_pg(ID, tsIn, specIn):
 
     return specOut
 
-
-class session():
+class session(file_logger):
     """ Main class used to initiate peakbagging.
 
     Use this class to initialize a star class instance for one or more targets.
@@ -518,7 +524,18 @@ class session():
     download_dir : str, optional
         Directory to cache lightkurve downloads. Lightkurve will place the fits
         files in the default lightkurve cache path in your home directory.
-
+    session_ID : str, optional
+        Session identifier. Default is ``'session'``. This is the name given to
+        the ``log_file`` for the session. Give this a unique name when running
+        multiple sessions with the same ``path``, otherwise logs will be appended
+        to the same file.
+    logging_level : str, optional
+        Level at which logs will be recorded to a log file called
+        '{session_ID}.log' at ``path``. Default is 'DEBUG' (recommended). Choose
+        from 'DEBUG', 'INFO', 'WARNING', 'ERROR' and 'CRITICAL'. All logs at
+        levels including and following ``logging_level`` will be recorded to the
+        file.
+
     Attributes
     ----------
     stars : list
@@ -532,68 +549,92 @@ class session():
     def __init__(self, ID=None, numax=None, dnu=None, teff=None, bp_rp=None,
                  timeseries=None, spectrum=None, dictlike=None, use_cached=False,
                  cadence=None, campaign=None, sector=None, month=None,
-                 quarter=None, mission=None, path=None, download_dir=None):
-
-        self.stars = []
-        self.references = references()
-        self.references._addRef(['python', 'pandas', 'numpy', 'astropy',
-                                 'lightkurve'])
+                 quarter=None, mission=None, path=None, download_dir=None,
+                 session_ID=None, logging_level='DEBUG'):
+
+        self.session_ID = session_ID or 'session'
+        logfilename = os.path.join(path or os.getcwd(), f'{self.session_ID}.log')
+        super(session, self).__init__(filename=logfilename, level=logging_level, loggername='pbjam.session')
 
-        if isinstance(dictlike, (dict, np.recarray, pd.DataFrame, str)):
-            if isinstance(dictlike, str):
-                vardf = pd.read_csv(dictlike)
+        with self.log_file:
+            # Records everything in context to the log file
+            logger.info('Starting session.')
+            self.stars = []
+            self.references = references()
+            self.references._addRef(['python', 'pandas', 'numpy', 'astropy',
+                                     'lightkurve'])
+
+            if isinstance(dictlike, (dict, np.recarray, pd.DataFrame, str)):
+                if isinstance(dictlike, str):
+                    vardf = pd.read_csv(dictlike)
+                else:
+                    try:
+                        vardf = pd.DataFrame.from_records(dictlike)
+                    except TypeError:
+                        # TODO: Shouldn't this raise an exception?
+                        logger.critical('Unrecognized type in dictlike. Must be able to convert to dataframe through pandas.DataFrame.from_records()')
+
+                if any([ID, numax, dnu, teff, bp_rp]):
+                    logger.warning('Dictlike provided as input, ignoring other input fit parameters.')
+
+                _organize_sess_dataframe(vardf)
+
+            elif ID:
+                vardf = _organize_sess_input(ID=ID, numax=numax, dnu=dnu, teff=teff,
+                                             bp_rp=bp_rp, cadence=cadence,
+                                             campaign=campaign, sector=sector,
+                                             month=month, quarter=quarter,
+                                             mission=mission)
+
+                _format_col(vardf, timeseries, 'timeseries')
+                _format_col(vardf, spectrum, 'spectrum')
             else:
-                try:
-                    vardf = pd.DataFrame.from_records(dictlike)
-                except TypeError:
-                    print('Unrecognized type in dictlike. Must be able to convert to dataframe through pandas.DataFrame.from_records()')
-
-            if any([ID, numax, dnu, teff, bp_rp]):
-                warnings.warn('Dictlike provided as input, ignoring other input fit parameters.')
-
-            _organize_sess_dataframe(vardf)
-
-        elif ID:
-            vardf = _organize_sess_input(ID=ID, numax=numax, dnu=dnu, teff=teff,
-                                         bp_rp=bp_rp, cadence=cadence,
-                                         campaign=campaign, sector=sector,
-                                         month=month, quarter=quarter,
-                                         mission=mission)
-
-            _format_col(vardf, timeseries, 'timeseries')
-            _format_col(vardf, spectrum, 'spectrum')
-        else:
-            raise TypeError('session.__init__ requires either ID or dictlike')
+                raise TypeError('session.__init__ requires either ID or dictlike')
 
-        for i in vardf.index:
-
-            lkwargs = {x: vardf.loc[i, x] for x in ['cadence', 'month',
-                                                    'sector', 'campaign',
-                                                    'quarter', 'mission']}
-
-            vardf.at[i, 'timeseries'] = _lc_to_lk(vardf.loc[i, 'ID'],
-                                                  vardf.loc[i, 'timeseries'],
-                                                  vardf.loc[i, 'spectrum'],
-                                                  download_dir,
-                                                  use_cached,
-                                                  lkwargs)
-
-            vardf.at[i,'spectrum'] = _lk_to_pg(vardf.loc[i,'ID'],
-                                               vardf.loc[i, 'timeseries'],
-                                               vardf.loc[i, 'spectrum'])
-
-            self.stars.append(star(ID=vardf.loc[i, 'ID'],
-                                   pg=vardf.loc[i, 'spectrum'],
-                                   numax=vardf.loc[i, ['numax', 'numax_err']].values,
-                                   dnu=vardf.loc[i, ['dnu', 'dnu_err']].values,
-                                   teff=vardf.loc[i, ['teff', 'teff_err']].values,
-                                   bp_rp=vardf.loc[i, ['bp_rp', 'bp_rp_err']].values,
-                                   path=path))
-
-        for i, st in enumerate(self.stars):
-            if st.numax[0] > st.f[-1]:
-                warnings.warn("Input numax is greater than Nyquist frequeny for %s" % (st.ID))
+            with pd.option_context(
+                'display.max_rows', None, 'display.max_columns', None,
+                'expand_frame_repr', False, 'max_colwidth', 15
+            ):
+                logger.debug('Input DataFrame:\n' + str(vardf))
+
+            for i in vardf.index:
+
+                lkwargs = {x: vardf.loc[i, x] for x in ['cadence', 'month',
+                                                        'sector', 'campaign',
+                                                        'quarter', 'mission']}
+
+                vardf.at[i, 'timeseries'] = _lc_to_lk(vardf.loc[i, 'ID'],
+                                                      vardf.loc[i, 'timeseries'],
+                                                      vardf.loc[i, 'spectrum'],
+                                                      download_dir,
+                                                      use_cached,
+                                                      lkwargs)
+
+                vardf.at[i,'spectrum'] = _lk_to_pg(vardf.loc[i,'ID'],
+                                                   vardf.loc[i, 'timeseries'],
+                                                   vardf.loc[i, 'spectrum'])
+
+                logger.debug(f'Adding star with ID {repr(vardf.loc[i, "ID"])}')
+                self.stars.append(star(ID=vardf.loc[i, 'ID'],
+                                       pg=vardf.loc[i, 'spectrum'],
+                                       numax=vardf.loc[i, ['numax', 'numax_err']].values,
+                                       dnu=vardf.loc[i, ['dnu', 'dnu_err']].values,
+                                       teff=vardf.loc[i, ['teff', 'teff_err']].values,
+                                       bp_rp=vardf.loc[i, ['bp_rp', 'bp_rp_err']].values,
+                                       path=path))
+
+            for i, st in enumerate(self.stars):
+                if st.numax[0] > st.f[-1]:
+                    # TODO: should this raise an exception? We know this will break later on.
+                    logger.critical("Input numax is greater than Nyquist frequency for %s" % (st.ID))
+
+    def __repr__(self):
+        """ Repr for the ``session`` class. """
+        return f'<pbjam.session.session>'
+
+    @file_logger.listen
+    @debugger
     def __call__(self, bw_fac=1, norders=8, model_type='simple', tune=1500,
                  nthreads=1, verbose=False, make_plots=False,
                  store_chains=False, asy_sampling='emcee', developer_mode=False):
@@ -633,15 +674,14 @@ def __call__(self, bw_fac=1, norders=8, model_type='simple', tune=1500,
         the prior sample. Important: This is not good practice for getting
         science results!
""" - self.pb_model_type = model_type for i, st in enumerate(self.stars): try: st(bw_fac=bw_fac, tune=tune, norders=norders, - model_type=self.pb_model_type, make_plots=make_plots, - store_chains=store_chains, nthreads=nthreads, - asy_sampling=asy_sampling, developer_mode=developer_mode) + model_type=self.pb_model_type, make_plots=make_plots, + store_chains=store_chains, nthreads=nthreads, + asy_sampling=asy_sampling, developer_mode=developer_mode) self.references._reflist += st.references._reflist @@ -649,10 +689,14 @@ def __call__(self, bw_fac=1, norders=8, model_type='simple', tune=1500, # Crude way to send error messages that occur in star up to Session # without ending the session. Is there a better way? + # Yes: logging using `logger.exception` - logs full traceback! except Exception as ex: - message = "Star {0} produced an exception of type {1} occurred. Arguments:\n{2!r}".format(st.ID, type(ex).__name__, ex.args) - print(message) - + # message = "Star {0} produced an exception of type {1} occurred. Arguments:\n{2!r}".format(st.ID, type(ex).__name__, ex.args) + # print(message) + logger.exception(f"{st} failed due to the following exception, continuing to the next star.") + + +@debugger def _load_fits(files, mission): """ Read fitsfiles into a Lightkurve object @@ -680,6 +724,7 @@ def _load_fits(files, mission): lc = lccol.PDCSAP_FLUX.stitch() return lc +@debugger def _set_mission(ID, lkwargs): """ Set mission keyword in lkwargs. @@ -705,7 +750,8 @@ def _set_mission(ID, lkwargs): lkwargs['mission'] = 'TESS' else: lkwargs['mission'] = ('Kepler', 'K2', 'TESS') - + +@debugger def _search_and_dump(ID, lkwargs, search_cache): """ Get lightkurve search result online. @@ -745,6 +791,7 @@ def _search_and_dump(ID, lkwargs, search_cache): return resultDict +@debugger def _getMASTidentifier(ID, lkwargs): """ return KIC/TIC/EPIC for given ID. @@ -789,6 +836,7 @@ def _getMASTidentifier(ID, lkwargs): ID = ID.replace(' ', '') return ID +@debugger def _perform_search(ID, lkwargs, use_cached=True, download_dir=None, cache_expire=30): """ Find filenames related to target @@ -844,6 +892,7 @@ def _perform_search(ID, lkwargs, use_cached=True, download_dir=None, return resultDict['result'] +@debugger def _check_lc_cache(search, mission, download_dir=None): """ Query cache directory or download fits files. @@ -886,6 +935,7 @@ def _check_lc_cache(search, mission, download_dir=None): return files_in_cache +@debugger def _clean_lc(lc): """ Perform Lightkurve operations on object. diff --git a/pbjam/star.py b/pbjam/star.py index ea88128..762cd8f 100644 --- a/pbjam/star.py +++ b/pbjam/star.py @@ -27,8 +27,14 @@ from astroquery.simbad import Simbad import astropy.units as units +import logging +from .jar import debug, file_logger -class star(plotting): +logger = logging.getLogger(__name__) # For module-level logging +debugger = debug(logger) + + +class star(plotting, file_logger): """ Class for each star to be peakbagged Additional attributes are added for each step of the peakbagging process @@ -68,6 +74,11 @@ class star(plotting): prior_file : str, optional Path to the csv file containing the prior data. Default is pbjam/data/prior_data.csv + logging_level : str, optional + Level at which logs will be recorded to a log file called '{ID}.log' + at ``path``. Default is 'DEBUG' (recommended). Choose from 'DEBUG', + 'INFO', 'WARNING', 'ERROR' and 'CRITICAL'. All logs at levels including + and following ``logging_level`` will be recorded to the file. 
     Attributes
     ----------
@@ -79,37 +90,43 @@ class star(plotting):
     """
 
     def __init__(self, ID, pg, numax, dnu, teff=[None,None], bp_rp=[None,None],
-                 path=None, prior_file=None):
-
+                 path=None, prior_file=None, logging_level='DEBUG'):
+
         self.ID = ID
+        self._set_outpath(path)
+        logfilename = os.path.join(self.path, f'{self.ID}.log')
+        super(star, self).__init__(filename=logfilename, level=logging_level)
 
-        if numax[0] < 25:
-            warnings.warn('The input numax is less than 25. The prior is not well defined here, so be careful with the result.')
-        self.numax = numax
-        self.dnu = dnu
+        with self.log_file:
+            logger.info(f"Initializing star with ID {repr(self.ID)}.")
 
-        self.references = references()
-        self.references._addRef(['numpy', 'python', 'lightkurve', 'astropy'])
-
-        teff, bp_rp = self._checkTeffBpRp(teff, bp_rp)
-        self.teff = teff
-        self.bp_rp = bp_rp
+            if numax[0] < 25:
+                logger.warning('The input numax is less than 25. The prior is not well defined here, so be careful with the result.')
+            self.numax = numax
+            self.dnu = dnu
 
-        self.pg = pg.flatten() # in case user supplies unormalized spectrum
-        self.f = self.pg.frequency.value
-        self.s = self.pg.power.value
+            self.references = references()
+            self.references._addRef(['numpy', 'python', 'lightkurve', 'astropy'])
+
+            teff, bp_rp = self._checkTeffBpRp(teff, bp_rp)
+            self.teff = teff
+            self.bp_rp = bp_rp
 
-        self._obs = {'dnu': self.dnu, 'numax': self.numax, 'teff': self.teff,
-                     'bp_rp': self.bp_rp}
-        self._log_obs = {x: to_log10(*self._obs[x]) for x in self._obs.keys() if x != 'bp_rp'}
+            self.pg = pg.flatten() # in case user supplies unnormalized spectrum
+            self.f = self.pg.frequency.value
+            self.s = self.pg.power.value
 
-        self._set_outpath(path)
+            self._obs = {'dnu': self.dnu, 'numax': self.numax, 'teff': self.teff,
+                         'bp_rp': self.bp_rp}
+            self._log_obs = {x: to_log10(*self._obs[x]) for x in self._obs.keys() if x != 'bp_rp'}
 
-        if prior_file is None:
-            self.prior_file = get_priorpath()
-        else:
-            self.prior_file = prior_file
-
+            if prior_file is None:
+                self.prior_file = get_priorpath()
+            else:
+                self.prior_file = prior_file
+
+    def __repr__(self):
+        return f'<pbjam.star.star>'
 
     def _checkTeffBpRp(self, teff, bp_rp):
         """ Set the Teff and/or bp_rp values
 
@@ -218,11 +235,11 @@ def _set_outpath(self, path):
             try:
                 os.makedirs(self.path)
             except Exception as ex:
-                message = "Could not create directory for Star {0} because an exception of type {1} occurred. Arguments:\n{2!r}".format(self.ID, type(ex).__name__, ex.args)
-                print(message)
-
-
+                # message = "Could not create directory for Star {0} because an exception of type {1} occurred. Arguments:\n{2!r}".format(self.ID, type(ex).__name__, ex.args)
+                logger.exception(f"Could not create directory for star {self.ID}.")
 
+    @file_logger.listen
+    @debugger
     def run_kde(self, bw_fac=1.0, make_plots=False, store_chains=False):
         """ Run all steps involving KDE.
@@ -242,7 +259,7 @@ def run_kde(self, bw_fac=1.0, make_plots=False, store_chains=False):
 
         """
 
-        print('Starting KDE estimation')
+        logger.info('Starting KDE estimation')
 
         # Init
         kde(self)
@@ -265,8 +282,9 @@ def run_kde(self, bw_fac=1.0, make_plots=False, store_chains=False):
         if store_chains:
             kde_samps = pd.DataFrame(self.kde.samples, columns=self.kde.par_names)
             kde_samps.to_csv(self._get_outpath(f'kde_chains_{self.ID}.csv'), index=False)
-
+
+    @file_logger.listen
+    @debugger
     def run_asy_peakbag(self, norders, make_plots=False,
                         store_chains=False, method='emcee',
                         developer_mode=False):
 
@@ -295,7 +313,7 @@ def run_asy_peakbag(self, norders, make_plots=False,
 
         """
 
-        print('Starting asymptotic peakbagging')
+        logger.info('Starting asymptotic peakbagging')
 
         # Init
         asymptotic_fit(self, norders=norders)
@@ -320,8 +338,9 @@ def run_asy_peakbag(self, norders, make_plots=False,
         if store_chains:
             asy_samps = pd.DataFrame(self.asy_fit.samples, columns=self.asy_fit.par_names)
             asy_samps.to_csv(self._get_outpath(f'asymptotic_fit_chains_{self.ID}.csv'), index=False)
-
+
+    @file_logger.listen
+    @debugger
     def run_peakbag(self, model_type='simple', tune=1500, nthreads=1,
                     make_plots=False, store_chains=False):
         """ Run all steps involving peakbag.
 
@@ -345,7 +364,7 @@ def run_peakbag(self, model_type='simple', tune=1500, nthreads=1,
 
         """
 
-        print('Starting peakbagging')
+        logger.info('Starting peakbagging')
 
         # Init
         peakbag(self, self.asy_fit)
@@ -366,8 +385,8 @@ def run_peakbag(self, model_type='simple', tune=1500, nthreads=1,
         if store_chains:
             peakbag_samps = pd.DataFrame(self.peakbag.samples, columns=self.peakbag.par_names)
             peakbag_samps.to_csv(self._get_outpath(f'peakbag_chains_{self.ID}.csv'), index=False)
-
+
+    @file_logger.listen
     def __call__(self, bw_fac=1.0, norders=8, model_type='simple', tune=1500,
                  nthreads=1, make_plots=True, store_chains=False,
                  asy_sampling='emcee', developer_mode=False):
 
@@ -404,18 +423,19 @@ def __call__(self, bw_fac=1.0, norders=8, model_type='simple', tune=1500,
         the prior sample. Important: This is not good practice for getting
         science results!
         """
-
         self.run_kde(bw_fac=bw_fac, make_plots=make_plots, store_chains=store_chains)
 
         self.run_asy_peakbag(norders=norders, make_plots=make_plots,
-                             store_chains=store_chains, method=asy_sampling,
-                             developer_mode=developer_mode)
+                             store_chains=store_chains, method=asy_sampling,
+                             developer_mode=developer_mode)
 
         self.run_peakbag(model_type=model_type, tune=tune, nthreads=nthreads,
-                         make_plots=make_plots, store_chains=store_chains)
+                         make_plots=make_plots, store_chains=store_chains)
 
         self.references._addRef('pandas')
 
+
+@debugger
 def _querySimbad(ID):
     """ Query any ID at Simbad for Gaia DR2 source ID.
 
@@ -442,12 +462,12 @@ def _querySimbad(ID):
         Gaia DR2 source ID. Returns None if no Gaia ID is found.
     """
 
-    print('Querying Simbad for Gaia ID')
+    logger.debug('Querying Simbad for Gaia ID.')
 
     try:
         job = Simbad.query_objectids(ID)
     except:
-        print(f'Unable to resolve {ID} with Simbad')
+        logger.debug(f'Unable to resolve {ID} with Simbad.')
         return None
 
     for line in job['ID']:
@@ -455,6 +475,7 @@ def _querySimbad(ID):
             return line.replace('Gaia DR2 ', '')
     return None
 
+@debugger
 def _queryTIC(ID, radius = 20):
     """ Query TIC for bp-rp value
 
@@ -481,7 +502,7 @@ def _queryTIC(ID, radius = 20):
         Gaia bp-rp value from the TIC.
""" - print('Querying TIC for Gaia bp-rp values.') + logger.debug('Querying TIC for Gaia bp-rp values.') job = Catalogs.query_object(objectname=ID, catalog='TIC', objType='STAR', radius = radius*units.arcsec) @@ -491,6 +512,7 @@ def _queryTIC(ID, radius = 20): else: return None +@debugger def _queryMAST(ID): """ Query any ID at MAST @@ -512,7 +534,7 @@ def _queryMAST(ID): """ - print(f'Querying MAST for the {ID} coordinates.') + logger.debug(f'Querying MAST for the {ID} coordinates.') mastobs = AsqMastObsCl() try: @@ -520,6 +542,7 @@ def _queryMAST(ID): except: return None +@debugger def _queryGaia(ID=None, coords=None, radius=2): """ Query Gaia archive for bp-rp @@ -543,32 +566,35 @@ def _queryGaia(ID=None, coords=None, radius=2): ------- bp_rp : float Gaia bp-rp value of the requested target from the Gaia archive. - """ - + """ from astroquery.gaia import Gaia + logger.debug('Querying Gaia archive for bp-rp values.') if ID is not None: - print('Querying Gaia archive for bp-rp values by target ID.') + logger.info('Querying Gaia archive for bp-rp values by target ID.') adql_query = "select * from gaiadr2.gaia_source where source_id=%s" % (ID) try: job = Gaia.launch_job(adql_query).get_results() except: + logger.debug(f'Unable to query Gaia archive using ID={ID}.') return None return float(job['bp_rp'][0]) elif coords is not None: - print('Querying Gaia archive for bp-rp values by target coordinates.') + logger.info('Querying Gaia archive for bp-rp values by target coordinates.') ra = coords.ra.value dec = coords.dec.value adql_query = f"SELECT DISTANCE(POINT('ICRS', {ra}, {dec}), POINT('ICRS', ra, dec)) AS dist, * FROM gaiaedr3.gaia_source WHERE 1=CONTAINS( POINT('ICRS', {ra}, {dec}), CIRCLE('ICRS', ra, dec,{radius})) ORDER BY dist ASC" try: job = Gaia.launch_job(adql_query).get_results() except: + logger.debug('Unable to query Gaia archive using coords={coords}.') return None return float(job['bp_rp'][0]) else: raise ValueError('No ID or coordinates provided when querying the Gaia archive.') +@debugger def _format_name(ID): """ Format input ID @@ -613,6 +639,7 @@ def _format_name(ID): return fname return ID +@debugger def get_bp_rp(ID): """ Search online for bp_rp values based on ID. 
@@ -653,7 +680,9 @@ def get_bp_rp(ID): try: coords = _queryMAST(ID) bp_rp = _queryGaia(coords=coords) - except: - print(f'Unable to retrieve a bp_rp value for {ID}.') + except Exception as exc: + # Note that logger.exception gives the full Traceback or just set exc_info + logger.debug(f'Exception: {exc}.', exc_info=1) + logger.warning(f'Unable to retrieve a bp_rp value for {ID}.') bp_rp = np.nan - return bp_rp \ No newline at end of file + return bp_rp diff --git a/pbjam/tests/test_asy_peakbag.py b/pbjam/tests/test_asy_peakbag.py index 7aa8a02..9ac8b7f 100644 --- a/pbjam/tests/test_asy_peakbag.py +++ b/pbjam/tests/test_asy_peakbag.py @@ -321,7 +321,7 @@ def test_asymp_spec_model_call(): assert_allclose(mod(inp), mod.model(*inp)) def test_clean_up(): - + os.remove(cs.st.log_file._filename) # Removes log file os.rmdir(cs.st.path) # The test functions below require longer runs and are not suitable for GitHub diff --git a/pbjam/tests/test_jar.py b/pbjam/tests/test_jar.py index 227721b..847de2b 100644 --- a/pbjam/tests/test_jar.py +++ b/pbjam/tests/test_jar.py @@ -1,9 +1,12 @@ """Tests for the jar module""" -from pbjam.jar import normal, to_log10, get_priorpath, get_percentiles +from pbjam.jar import normal, to_log10, get_priorpath, get_percentiles, log_file, file_logger, debug import pbjam.tests.pbjam_tests as pbt import numpy as np from numpy.testing import assert_almost_equal, assert_array_equal +import logging, os + +logger = logging.getLogger('pbjam.tests') def test_normal(): """Test for the log of a normal distribution""" @@ -79,4 +82,106 @@ def test_get_percentiles(): inp = [[0,0,0,1,1], 1] assert_array_equal(func(*inp), [0., 0., 1.]) - \ No newline at end of file +def test_file_logger(): + """Tests subclassing ``jam`` to use the log file record decorator""" + test_message = 'This should be logged in file.' + + class file_logger_test(file_logger): + def __init__(self): + super(file_logger_test, self).__init__(filename='test_jam.log') + logger.debug('This should not be logged in file.') + with self.log_file: + # Records content in context to `log_file` + logger.debug(test_message) + + @file_logger.listen # records content of `example_method` to `log_file` + def method(self): + logger.debug(test_message) + + jt = file_logger_test() + jt.method() + + filename = jt.log_file._filename + with open(filename, 'r') as file_in: + lines = file_in.read().splitlines() + messages = [line.split('::')[-1].strip() for line in lines] + assert(all([message == test_message for message in messages])) + + os.remove(filename) + +def test_log_file(): + """Test ``file_logger`` context manager.""" + filename = 'test_file_logger.log' + test_level = 'DEBUG' + flog = log_file(filename, level=test_level) + + with flog: + test_message = 'This should be logged in file.' + logger.debug(test_message) + logger.debug('This should not be logged in file') + + with open(filename, 'r') as file_in: + lines = file_in.read().splitlines() + assert(len(lines) == 1) + + record = lines.pop().split('::') + level = record[1].strip() + assert(level == test_level) + + message = record[-1].strip() + assert(message == test_message) + + os.remove(filename) + +def test_debug_logger(): + """Tests ``log`` decorator debug messages""" + test_message = 'Function in progress.' 
+
+    @debug(logger)
+    def log_test():
+        logger.debug(test_message)
+
+    filename = 'test_log.log'
+    flog = log_file(filename)
+
+    with flog:
+        log_test()
+
+    with open(filename, 'r') as file_in:
+        lines = file_in.read().splitlines()
+
+    messages = [line.split('::')[-1].strip() for line in lines]
+
+    end = log_test.__qualname__
+    assert(messages[0].startswith('Entering') and messages[0].endswith(end))
+    assert(messages[-1].startswith('Exiting') and messages[-1].endswith(end))
+    assert(test_message in messages)
+
+    os.remove(filename)
+
+def test_debug_info():
+    """Tests ``debug`` decorator with INFO level."""
+
+    test_message = 'Function in progress.'
+
+    @debug(logger)
+    def log_test():
+        logger.debug(test_message)
+        logger.info(test_message)
+        logger.warning(test_message)
+        logger.error(test_message)
+        logger.critical(test_message)
+
+    filename = 'test_log.log'
+    flog = log_file(filename, level='INFO') # level='INFO' same as console_handler
+
+    with flog:
+        log_test()
+
+    with open(filename, 'r') as file_in:
+        lines = file_in.read().splitlines()
+
+    levels = [line.split('::')[0].strip() for line in lines]
+    assert('DEBUG' not in levels)
+
+    os.remove(filename)
diff --git a/pbjam/tests/test_priors.py b/pbjam/tests/test_priors.py
index f5c63ba..9cbe37d 100644
--- a/pbjam/tests/test_priors.py
+++ b/pbjam/tests/test_priors.py
@@ -55,10 +55,6 @@ def test_prior_size_check():
         for sigma in [10, 100]:
             pdata_cut = func(pdata, to_log10(numax, sigma), KDEsize)
             assert((len(pdata_cut) > 0) & (len(pdata_cut) <= KDEsize))
-
-    # These combinations should show warnings
-    with pytest.warns(UserWarning):
-        func(pdata, to_log10(300, 1), 500)
 
     # These combinations should raise errors
     with pytest.raises(ValueError):
diff --git a/pbjam/tests/test_star.py b/pbjam/tests/test_star.py
index 25e0a7c..0ba5b07 100644
--- a/pbjam/tests/test_star.py
+++ b/pbjam/tests/test_star.py
@@ -41,6 +41,7 @@ def test_star_init():
     pbt.assert_hasattributes(st, atts)
 
     # cleanup
+    os.remove(st.log_file._filename) # Remove log file
     os.rmdir(st.path)
 
 def test_outpath():
@@ -68,6 +69,7 @@ def test_outpath():
     assert(os.path.isdir(os.path.dirname(func(*inp))))
 
     # cleanup
+    os.remove(st.log_file._filename)
     os.rmdir(st.path)
 
 def test_set_outpath():
@@ -88,6 +90,7 @@ def test_set_outpath():
 
     # Input tests and clean up
     assert(os.path.isdir(st.path))
+    os.remove(st.log_file._filename)
     os.rmdir(st.path)
 
     inp = [pth]
@@ -118,6 +121,7 @@ def test_run_kde():
     pbt.does_it_run(func, None)
 
     # cleanup
+    os.remove(st.log_file._filename)
     os.rmdir(st.path)
 
 def test_format_name():
diff --git a/requirements.txt b/requirements.txt
index 17bf663..be59dd0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,3 +12,4 @@ hdbscan
 scikit-learn<=0.22.0
 nbsphinx
 cpnest>=0.9.9
+arviz
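
A minimal usage sketch of the logging utilities this diff introduces. The ``log_file`` context manager and ``debug`` decorator are the helpers added to ``pbjam/jar.py`` above; the logger name, filename, function and numax value below are hypothetical, and the scaling-relation estimate is illustrative only, not a pbjam API:

.. code-block:: python

    import logging
    from pbjam.jar import log_file, debug

    # Module-level logger under the 'pbjam' namespace, as recommended in the
    # docstring of `debug`; records propagate up to the 'pbjam' logger.
    logger = logging.getLogger('pbjam.demo')
    debugger = debug(logger)

    @debugger
    def dnu_guess(numax):
        # Rough large-separation estimate from a numax scaling relation
        return 0.263 * numax**0.772

    # Everything logged under 'pbjam' inside this context goes to demo.log,
    # since log_file defaults to loggername='pbjam' and level='DEBUG'.
    with log_file('demo.log'):
        logger.info(f'dnu guess = {dnu_guess(103.2):.2f} uHz')
    # Logs emitted here are no longer recorded to demo.log

The decorator writes the Entering/Signature/Bound arguments/Exiting records described in its docstring, and the file handler formats each record with the timestamp, level and logger-name columns defined by ``HANDLER_FMT``.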