340 changes: 162 additions & 178 deletions pixi.lock

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions pyproject.toml
@@ -257,11 +257,11 @@ max-complexity = 12
]

[tool.ruff.lint.pylint]
max-args = 5 # (PLR0913) Maximum number of arguments for function / method
max-bool-expr = 5 # ( PLR0916) Boolean in a single if statement
max-args=15 # (PLR0913) Maximum number of arguments for function / method
max-bool-expr=5 # ( PLR0916) Boolean in a single if statement
max-branches=12 # (PLR0912) branches allowed for a function or method body
max-locals=15 # (PLR0912) local variables allowed for a function or method body
max-nested-blocks = 5 # (PLR1702) nested blocks within a function or method body
max-nested-blocks=5 # (PLR1702) nested blocks within a function or method body
max-public-methods=20 # (R0904) public methods allowed for a class
max-returns=6 # (PLR0911) return statements for a function or method body
max-statements=50 # (PLR0915) statements allowed for a function or method body
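
An illustrative note on the relaxed PLR0913 limit above: with max-args = 5, ruff flags any signature with more than five arguments, while max-args = 15 leaves signatures like the sketch below alone. The module and function names here are hypothetical, not from this PR.

# hypothetical example, not part of the diff
def build_forward_pass(  # seven arguments: flagged under max-args=5, allowed under 15
    file_paths,
    model_dir,
    chunk_shape,
    spatial_pad,
    temporal_pad,
    out_pattern,
    log_level,
):
    return (file_paths, model_dir, chunk_shape, spatial_pad,
            temporal_pad, out_pattern, log_level)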
10 changes: 6 additions & 4 deletions sup3r/models/abstract.py
@@ -1104,7 +1104,8 @@ def generate(

return self._combine_fwp_output(hi_res, exogenous_data)

def _run_exo_layer(self, layer, input_array, hi_res_exo):
@classmethod
def _run_exo_layer(cls, layer, input_array, hi_res_exo):
"""Private run_exo_layer method used in ``_tf_generate``. Runs a layer
that combines exogenous data with the hi_res data. These layers can
include single or multiple exogenous features."""
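
A minimal sketch (generic class names, not sup3r's) of the refactor applied to _run_exo_layer above: a helper that never touches instance state is promoted to a classmethod, so its independence from self is explicit and it can be called without building a model.

class Before:
    def _run_helper(self, layer, data):  # takes self but never uses it
        return layer(data)

class After:
    @classmethod
    def _run_helper(cls, layer, data):  # classmethod: no instance required
        return layer(data)

# the refactored helper can be exercised directly on the class
assert After._run_helper(lambda x: x + 1, 1) == 2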
@@ -1227,9 +1228,10 @@ def get_single_grad(
loss_details : dict
Namespace of the breakdown of loss components
"""
with tf.device(device_name), tf.GradientTape(
watch_accessed_variables=False
) as tape:
with (
tf.device(device_name),
tf.GradientTape(watch_accessed_variables=False) as tape,
):
tape.watch(training_weights)
loss, loss_details, _, _ = self._get_hr_exo_and_loss(
low_res, hi_res_true, **calc_loss_kwargs
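
The get_single_grad change above rewrites the stacked context managers using the parenthesized with-statement form available in Python 3.10+. A minimal sketch with stand-in managers from contextlib (not TensorFlow) showing the two equivalent spellings:

from contextlib import nullcontext

# older spelling: both managers wrapped inside the call parentheses
with nullcontext('device'), nullcontext('tape') as tape:
    print(tape)

# Python 3.10+ spelling: parentheses give each manager its own line
with (
    nullcontext('device'),
    nullcontext('tape') as tape,
):
    print(tape)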
38 changes: 25 additions & 13 deletions sup3r/models/base.py
@@ -638,8 +638,8 @@ def train(
adaptive_update_bounds=(0.9, 0.99),
adaptive_update_fraction=0.0,
multi_gpu=False,
tensorboard_log=False,
tensorboard_profile=False,
log_tb=False,
export_tb=False,
):
"""Train the GAN model on real low res data and real high res data

@@ -703,12 +703,12 @@
rate that the model and optimizer were initialized with.
If true and multiple gpus are found, ``default_device`` device
should be set to /gpu:0
tensorboard_log : bool
log_tb : bool
Whether to write log file for use with tensorboard. Log data can
be viewed with ``tensorboard --logdir <logdir>`` where ``<logdir>``
is the parent directory of ``out_dir``, and pointing the browser to
the printed address.
tensorboard_profile : bool
export_tb : bool
Whether to export profiling information to tensorboard. This can
then be viewed in the tensorboard dashboard under the profile tab

@@ -720,10 +720,8 @@
(3) Would like an automatic way to exit the batch handler thread
instead of manually calling .stop() here.
"""
if tensorboard_log:
if log_tb:
self._init_tensorboard_writer(out_dir)
if tensorboard_profile:
self._write_tb_profile = True

self.set_norm_stats(batch_handler.means, batch_handler.stds)
params = self.check_batch_handler_attrs(batch_handler)
@@ -759,6 +757,7 @@
train_disc,
disc_loss_bounds,
multi_gpu=multi_gpu,
export_tb=export_tb,
)
loss_details.update(
self.calc_val_loss(batch_handler, weight_gen_advers)
@@ -1071,7 +1070,7 @@ def _post_batch(self, ib, b_loss_details, n_batches, previous_means):
disc_loss = self._train_record['train_loss_disc'].values.mean()
gen_loss = self._train_record['train_loss_gen'].values.mean()

logger.debug(
logger.info(
'Batch {} out of {} has (gen / disc) loss of: ({:.2e} / {:.2e}). '
'Running mean (gen / disc): ({:.2e} / {:.2e}). Trained '
'(gen / disc): ({} / {})'.format(
@@ -1102,6 +1101,7 @@ def _train_epoch(
train_disc,
disc_loss_bounds,
multi_gpu=False,
export_tb=False,
):
"""Train the GAN for one epoch.

@@ -1129,6 +1129,9 @@
rate that the model and optimizer were initialized with.
If true and multiple gpus are found, ``default_device`` device
should be set to /gpu:0
export_tb : bool
Whether to export profiling information to tensorboard. This can
then be viewed in the tensorboard dashboard under the profile tab

Returns
-------
@@ -1151,9 +1154,10 @@
only_gen = train_gen and not train_disc
only_disc = train_disc and not train_gen

if self._write_tb_profile:
if export_tb:
tf.summary.trace_on(graph=True, profiler=True)

prev_time = time.time()
for ib, batch in enumerate(batch_handler):
start = time.time()

@@ -1163,7 +1167,7 @@
disc_too_bad = (loss_disc > disc_th_high) and train_disc
gen_too_good = disc_too_bad

b_loss_details = self.timer(self._train_batch, log=True)(
b_loss_details = self._train_batch(
batch,
train_gen,
only_gen,
@@ -1175,17 +1179,25 @@
multi_gpu,
)

loss_means = self.timer(self._post_batch, log=True)(
loss_means = self._post_batch(
ib, b_loss_details, len(batch_handler), loss_means
)

total_step_time = time.time() - prev_time
batch_step_time = time.time() - start
batch_load_time = total_step_time - batch_step_time

logger.info(
f'Finished batch step {ib + 1} / {len(batch_handler)} in '
f'{time.time() - start:.4f} seconds'
f'{total_step_time:.4f} seconds. Batch load time: '
f'{batch_load_time:.4f} seconds. Batch train time: '
f'{batch_step_time:.4f} seconds.'
)

prev_time = time.time()

self.total_batches += len(batch_handler)
loss_details = self._train_record.mean().to_dict()
loss_details['total_batches'] = int(self.total_batches)
self.profile_to_tensorboard('training_epoch')
self.profile_to_tensorboard('training_epoch', export_tb=export_tb)
return loss_details
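
The new timing in the batch loop above splits each step into the time spent waiting on the batch handler (load time) and the time spent in the actual update (train time). A self-contained sketch of that bookkeeping with a dummy loader and a sleep in place of the gradient update (names illustrative, not sup3r's API):

import time

def fake_batches(n):
    for _ in range(n):
        time.sleep(0.01)  # stands in for the queue producing a batch
        yield object()

prev_time = time.time()
for ib, batch in enumerate(fake_batches(3)):
    start = time.time()
    time.sleep(0.02)  # stands in for the train step on this batch
    total_step_time = time.time() - prev_time   # load + train
    batch_step_time = time.time() - start       # train only
    batch_load_time = total_step_time - batch_step_time
    print(
        f'Finished batch step {ib + 1} in {total_step_time:.4f} seconds. '
        f'Batch load time: {batch_load_time:.4f} seconds. '
        f'Batch train time: {batch_step_time:.4f} seconds.'
    )
    prev_time = time.time()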
7 changes: 4 additions & 3 deletions sup3r/models/utilities.py
@@ -80,7 +80,6 @@ class TensorboardMixIn:
def __init__(self):
self._tb_writer = None
self._tb_log_dir = None
self._write_tb_profile = False
self._total_batches = None
self._history = None
self.timer = Timer()
@@ -116,15 +115,17 @@ def dict_to_tensorboard(self, entry):
else:
tf.summary.scalar(name, value, self.total_batches)

def profile_to_tensorboard(self, name):
def profile_to_tensorboard(self, name, export_tb=True):
"""Write profile data to tensorboard log file.

Parameters
----------
name : str
Tag name to use for profile info
export_tb : bool
Flag to enable/disable tensorboard profiling
"""
if self._tb_writer is not None and self._write_tb_profile:
if self._tb_writer is not None and export_tb:
with self._tb_writer.as_default():
tf.summary.trace_export(
name=name,
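
The export_tb flag gates TensorFlow's standard graph-tracing pattern: call tf.summary.trace_on before the traced work and tf.summary.trace_export inside the writer's context. A minimal graph-only sketch, assuming TensorFlow 2.x and a hypothetical ./tb_logs directory (the PR additionally enables the profiler, which also needs a profiler output directory):

import tensorflow as tf

log_dir = './tb_logs'  # hypothetical path; sup3r derives its log dir from out_dir
writer = tf.summary.create_file_writer(log_dir)

@tf.function
def train_step(x):
    return x * 2.0  # stands in for the real generator/discriminator update

tf.summary.trace_on(graph=True)  # start recording the traced graph
train_step(tf.constant(1.0))     # traced work happens before the export

with writer.as_default():
    tf.summary.trace_export(name='training_epoch', step=0)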
12 changes: 8 additions & 4 deletions sup3r/preprocessing/base.py
@@ -24,6 +24,7 @@
composite_info,
is_type_of,
)
from sup3r.utilities.utilities import Timer

logger = logging.getLogger(__name__)

@@ -44,7 +45,7 @@ class Sup3rMeta(ABCMeta, type):
kwargs as ``*args`` / ``**kwargs`` or those built through factory
composition, for example."""

def __new__(mcs, name, bases, namespace, **kwargs): # noqa: N804
def __new__(mcs, name, bases, namespace, **kwargs):
"""Define __name__ and __signature__"""
sig, doc = _get_class_info(namespace)
name = namespace.get('__name__', name)
@@ -201,11 +202,13 @@ def __getattr__(self, attr):
out = out[0]
return out

def _getattr(self, dset, attr):
@classmethod
def _getattr(cls, dset, attr):
"""Get attribute from single data member."""
return getattr(dset.sx, attr, getattr(dset, attr))

def _getitem(self, dset, item):
@classmethod
def _getitem(cls, dset, item):
"""Get item from single data member."""
return dset.sx[item] if hasattr(dset, 'sx') else dset[item]

@@ -318,7 +321,7 @@ class Container(metaclass=Sup3rMeta):
``Sup3rX`` objects (:class:`.Sup3rDataset`), or a tuple of such objects.
"""

__slots__ = ['_data']
__slots__ = ['_data', 'timer']

def __init__(
self,
@@ -353,6 +356,7 @@ def __init__(
:class:`~.samplers.DualSampler`, and a 1-tuple otherwise.
"""
self.data = data
self.timer = Timer()

@property
def data(self):
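
Because Container declares __slots__, a new instance attribute must be listed there before __init__ can assign it, which is why 'timer' is added alongside '_data' above. A minimal sketch of the constraint with generic classes:

class Slotted:
    __slots__ = ['_data']  # only _data may be set on instances

    def __init__(self, data):
        self._data = data

try:
    Slotted(1).timer = 'oops'
except AttributeError as err:
    print(err)  # assigning an undeclared attribute fails

class SlottedWithTimer:
    __slots__ = ['_data', 'timer']  # declaring the slot makes the assignment legal

    def __init__(self, data):
        self._data = data
        self.timer = 'ok'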
4 changes: 4 additions & 0 deletions sup3r/preprocessing/batch_handlers/factory.py
@@ -211,6 +211,7 @@
sample_shape=sample_shape,
feature_sets=feature_sets,
batch_size=batch_size,
mode=mode,
sampler_kwargs=sampler_kwargs,
)

@@ -260,6 +261,7 @@ def init_samplers(
sample_shape,
feature_sets,
batch_size,
mode,
sampler_kwargs,
):
"""Initialize samplers from given data containers."""
@@ -269,6 +271,7 @@
sample_shape=sample_shape,
feature_sets=feature_sets,
batch_size=batch_size,
mode=mode,
**sampler_kwargs,
)
for container in train_containers
@@ -282,6 +285,7 @@
sample_shape=sample_shape,
feature_sets=feature_sets,
batch_size=batch_size,
mode=mode,
**sampler_kwargs,
)
for container in val_containers
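
The factory change simply threads the new mode keyword from the batch handler down to every sampler it constructs. A minimal sketch of that kwarg-threading pattern with toy classes (not the real BatchHandler or Sampler signatures):

class ToySampler:
    def __init__(self, container, mode='lazy', **sampler_kwargs):
        self.container = container
        self.mode = mode
        self.kwargs = sampler_kwargs

class ToyBatchHandler:
    def __init__(self, train_containers, mode='lazy', sampler_kwargs=None):
        sampler_kwargs = sampler_kwargs or {}
        # mode is forwarded unchanged to each sampler, as in the PR
        self.samplers = [
            ToySampler(c, mode=mode, **sampler_kwargs)
            for c in train_containers
        ]

handler = ToyBatchHandler([object(), object()], mode='eager')
print([s.mode for s in handler.samplers])  # ['eager', 'eager']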
4 changes: 0 additions & 4 deletions sup3r/preprocessing/batch_queues/abstract.py
@@ -153,10 +153,6 @@ def preflight(self):
)
assert sampler_bs == self.batch_size, msg

if self.mode == 'eager':
logger.info('Received mode = "eager".')
_ = [c.compute() for c in self.containers]

@property
def queue_thread(self):
"""Get new queue thread."""
32 changes: 29 additions & 3 deletions sup3r/preprocessing/samplers/base.py
@@ -29,6 +29,7 @@ def __init__(
sample_shape: Optional[tuple] = None,
batch_size: int = 16,
feature_sets: Optional[dict] = None,
mode: str = 'lazy',
):
"""
Parameters
@@ -65,12 +66,18 @@
in the high-resolution observation but not expected to be
output from the generative model. An example is high-res
topography that is to be injected mid-network.
mode : str
Mode for sampling data. Options are 'lazy' or 'eager'. 'eager' mode
pre-loads all data into memory as numpy arrays for faster access.
'lazy' mode samples directly from the underlying data object, which
could be backed by dask arrays or on-disk netCDF files.
"""
super().__init__(data=data)
feature_sets = feature_sets or {}
self.features = feature_sets.get('features', self.data.features)
self._lr_only_features = feature_sets.get('lr_only_features', [])
self._hr_exo_features = feature_sets.get('hr_exo_features', [])
self.mode = mode
self.sample_shape = sample_shape or (10, 10, 1)
self.batch_size = batch_size
self.lr_features = self.features
@@ -133,6 +140,9 @@ def preflight(self):
if self.data.shape[2] < self.sample_shape[2] * self.batch_size:
logger.warning(msg)
warn(msg)
if self.mode == 'eager':
logger.info('Received mode = "eager".')
_ = self.compute()

@property
def sample_shape(self) -> tuple:
@@ -197,7 +207,8 @@ def _reshape_samples(self, samples):
# (batch_size, lats, lons, times, feats)
return np.transpose(out, axes=(2, 0, 1, 3, 4))

def _stack_samples(self, samples):
@classmethod
def _stack_samples(cls, samples):
"""Used to build batch arrays in the case of independent time samples
(e.g. slow batching)

@@ -225,10 +236,25 @@
return (lr, hr)
return np.stack(samples, axis=0)

def _compute_samples(self, samples):
"""Cast samples to numpy arrays. This only does something when samples
are dask arrays.

Parameters
----------
samples : tuple[np.ndarray | da.core.Array, ...] |
np.ndarray | da.core.Array
Samples retrieved from the underlying data. Could be a tuple
in the case of dual datasets.
"""
if self.mode == 'eager':
return samples
return compute_if_dask(samples)

def _fast_batch(self):
"""Get batch of samples with adjacent time slices."""
out = self.data.sample(self.get_sample_index(n_obs=self.batch_size))
out = compute_if_dask(out)
out = self._compute_samples(out)
if isinstance(out, tuple):
return tuple(self._reshape_samples(o) for o in out)
return self._reshape_samples(out)
@@ -239,7 +265,7 @@ def _slow_batch(self):
self.data.sample(self.get_sample_index(n_obs=1))
for _ in range(self.batch_size)
]
out = compute_if_dask(out)
out = self._compute_samples(out)
return self._stack_samples(out)

def _fast_batch_possible(self):
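
With this PR the eager/lazy decision lives in the sampler: eager mode computes the whole container once in preflight, so every sampled slice is already numpy, while lazy mode converts each sampled slice on demand. A minimal sketch of that split using dask arrays and a stand-in for sup3r's compute_if_dask helper:

import dask.array as da

def compute_if_dask(arr):
    # stand-in for sup3r's helper: materialize dask arrays, pass numpy through
    return arr.compute() if isinstance(arr, da.Array) else arr

class ToySampler:
    def __init__(self, data, mode='lazy'):
        self.mode = mode
        # eager mode pays the compute cost once, ahead of training
        self.data = compute_if_dask(data) if mode == 'eager' else data

    def sample(self, idx):
        out = self.data[idx]
        # lazy mode converts each sampled slice as it is requested
        return out if self.mode == 'eager' else compute_if_dask(out)

data = da.random.random((100, 10), chunks=(10, 10))
print(type(ToySampler(data, mode='eager').sample(slice(0, 4))))  # numpy.ndarray
print(type(ToySampler(data, mode='lazy').sample(slice(0, 4))))   # numpy.ndarray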