Commit cbf2eb5
Merge pull request #295 from NatLabRockies/bnb/mode_to_samplers
add mode kwarg to samplers.
2 parents ba95d5d + 7b284e9

File tree

15 files changed: +269 −216 lines

pixi.lock

Lines changed: 162 additions & 178 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 3 additions & 3 deletions
@@ -257,11 +257,11 @@ max-complexity = 12
 ]

 [tool.ruff.lint.pylint]
-max-args = 5 # (PLR0913) Maximum number of arguments for function / method
-max-bool-expr = 5 # ( PLR0916) Boolean in a single if statement
+max-args=15 # (PLR0913) Maximum number of arguments for function / method
+max-bool-expr=5 # ( PLR0916) Boolean in a single if statement
 max-branches=12 # (PLR0912) branches allowed for a function or method body
 max-locals=15 # (PLR0912) local variables allowed for a function or method body
-max-nested-blocks = 5 # (PLR1702) nested blocks within a function or method body
+max-nested-blocks=5 # (PLR1702) nested blocks within a function or method body
 max-public-methods=20 # (R0904) public methods allowed for a class
 max-returns=6 # (PLR0911) return statements for a function or method body
 max-statements=50 # (PLR0915) statements allowed for a function or method body

sup3r/models/abstract.py

Lines changed: 6 additions & 4 deletions
@@ -1104,7 +1104,8 @@ def generate(

         return self._combine_fwp_output(hi_res, exogenous_data)

-    def _run_exo_layer(self, layer, input_array, hi_res_exo):
+    @classmethod
+    def _run_exo_layer(cls, layer, input_array, hi_res_exo):
         """Private run_exo_layer method used in ``_tf_generate``. Runs a layer
         that combines exogenous data with the hi_res data. These layers can
         include single or multiple exogenous features."""
@@ -1227,9 +1228,10 @@ def get_single_grad(
         loss_details : dict
             Namespace of the breakdown of loss components
         """
-        with tf.device(device_name), tf.GradientTape(
-            watch_accessed_variables=False
-        ) as tape:
+        with (
+            tf.device(device_name),
+            tf.GradientTape(watch_accessed_variables=False) as tape,
+        ):
             tape.watch(training_weights)
             loss, loss_details, _, _ = self._get_hr_exo_and_loss(
                 low_res, hi_res_true, **calc_loss_kwargs
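
For context, the new ``with (...)`` form above is the parenthesized multi-context-manager syntax introduced officially in Python 3.10. A minimal standalone sketch of the same pattern, with illustrative model and data names that are not from sup3r:

    import tensorflow as tf

    def single_grad_sketch(model, low_res, hi_res_true, device_name='/gpu:0'):
        """Compute gradients with the device and tape contexts grouped in one
        parenthesized ``with`` statement (requires Python 3.10+)."""
        with (
            tf.device(device_name),
            tf.GradientTape(watch_accessed_variables=False) as tape,
        ):
            # Only the explicitly watched weights are traced.
            tape.watch(model.trainable_weights)
            loss = tf.reduce_mean(tf.square(model(low_res) - hi_res_true))
        return tape.gradient(loss, model.trainable_weights)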

sup3r/models/base.py

Lines changed: 25 additions & 13 deletions
@@ -638,8 +638,8 @@ def train(
         adaptive_update_bounds=(0.9, 0.99),
         adaptive_update_fraction=0.0,
         multi_gpu=False,
-        tensorboard_log=False,
-        tensorboard_profile=False,
+        log_tb=False,
+        export_tb=False,
     ):
         """Train the GAN model on real low res data and real high res data

@@ -703,12 +703,12 @@
            rate that the model and optimizer were initialized with.
            If true and multiple gpus are found, ``default_device`` device
            should be set to /gpu:0
-        tensorboard_log : bool
+        log_tb : bool
            Whether to write log file for use with tensorboard. Log data can
            be viewed with ``tensorboard --logdir <logdir>`` where ``<logdir>``
            is the parent directory of ``out_dir``, and pointing the browser to
            the printed address.
-        tensorboard_profile : bool
+        export_tb : bool
            Whether to export profiling information to tensorboard. This can
            then be viewed in the tensorboard dashboard under the profile tab

@@ -720,10 +720,8 @@
        (3) Would like an automatic way to exit the batch handler thread
        instead of manually calling .stop() here.
        """
-        if tensorboard_log:
+        if log_tb:
            self._init_tensorboard_writer(out_dir)
-        if tensorboard_profile:
-            self._write_tb_profile = True

        self.set_norm_stats(batch_handler.means, batch_handler.stds)
        params = self.check_batch_handler_attrs(batch_handler)
@@ -759,6 +757,7 @@
                train_disc,
                disc_loss_bounds,
                multi_gpu=multi_gpu,
+                export_tb=export_tb,
            )
            loss_details.update(
                self.calc_val_loss(batch_handler, weight_gen_advers)
@@ -1071,7 +1070,7 @@ def _post_batch(self, ib, b_loss_details, n_batches, previous_means):
        disc_loss = self._train_record['train_loss_disc'].values.mean()
        gen_loss = self._train_record['train_loss_gen'].values.mean()

-        logger.debug(
+        logger.info(
            'Batch {} out of {} has (gen / disc) loss of: ({:.2e} / {:.2e}). '
            'Running mean (gen / disc): ({:.2e} / {:.2e}). Trained '
            '(gen / disc): ({} / {})'.format(
@@ -1102,6 +1101,7 @@ def _train_epoch(
        train_disc,
        disc_loss_bounds,
        multi_gpu=False,
+        export_tb=False,
    ):
        """Train the GAN for one epoch.

@@ -1129,6 +1129,9 @@
            rate that the model and optimizer were initialized with.
            If true and multiple gpus are found, ``default_device`` device
            should be set to /gpu:0
+        export_tb : bool
+            Whether to export profiling information to tensorboard. This can
+            then be viewed in the tensorboard dashboard under the profile tab

        Returns
        -------
@@ -1151,9 +1154,10 @@
        only_gen = train_gen and not train_disc
        only_disc = train_disc and not train_gen

-        if self._write_tb_profile:
+        if export_tb:
            tf.summary.trace_on(graph=True, profiler=True)

+        prev_time = time.time()
        for ib, batch in enumerate(batch_handler):
            start = time.time()

@@ -1163,7 +1167,7 @@
            disc_too_bad = (loss_disc > disc_th_high) and train_disc
            gen_too_good = disc_too_bad

-            b_loss_details = self.timer(self._train_batch, log=True)(
+            b_loss_details = self._train_batch(
                batch,
                train_gen,
                only_gen,
@@ -1175,17 +1179,25 @@
                multi_gpu,
            )

-            loss_means = self.timer(self._post_batch, log=True)(
+            loss_means = self._post_batch(
                ib, b_loss_details, len(batch_handler), loss_means
            )

+            total_step_time = time.time() - prev_time
+            batch_step_time = time.time() - start
+            batch_load_time = total_step_time - batch_step_time
+
            logger.info(
                f'Finished batch step {ib + 1} / {len(batch_handler)} in '
-                f'{time.time() - start:.4f} seconds'
+                f'{total_step_time:.4f} seconds. Batch load time: '
+                f'{batch_load_time:.4f} seconds. Batch train time: '
+                f'{batch_step_time:.4f} seconds.'
            )

+            prev_time = time.time()
+
        self.total_batches += len(batch_handler)
        loss_details = self._train_record.mean().to_dict()
        loss_details['total_batches'] = int(self.total_batches)
-        self.profile_to_tensorboard('training_epoch')
+        self.profile_to_tensorboard('training_epoch', export_tb=export_tb)
        return loss_details
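
A hedged usage sketch of the renamed flags: only ``log_tb``, ``export_tb``, and ``out_dir`` are confirmed by this diff, and the model / batch handler setup and the ``n_epoch`` kwarg are assumptions for illustration:

    # Hypothetical call: assumes `model` (an initialized Sup3rGan) and
    # `batch_handler` already exist. Kwargs other than log_tb / export_tb /
    # out_dir are illustrative, not taken from this commit.
    model.train(
        batch_handler,
        n_epoch=10,          # assumed kwarg for illustration
        out_dir='./gan_out',
        log_tb=True,         # write scalars; view with `tensorboard --logdir <parent of out_dir>`
        export_tb=True,      # also export graph/profile traces (profile tab)
    )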

sup3r/models/utilities.py

Lines changed: 4 additions & 3 deletions
@@ -80,7 +80,6 @@ class TensorboardMixIn:
     def __init__(self):
         self._tb_writer = None
         self._tb_log_dir = None
-        self._write_tb_profile = False
         self._total_batches = None
         self._history = None
         self.timer = Timer()
@@ -116,15 +115,17 @@ def dict_to_tensorboard(self, entry):
         else:
             tf.summary.scalar(name, value, self.total_batches)

-    def profile_to_tensorboard(self, name):
+    def profile_to_tensorboard(self, name, export_tb=True):
         """Write profile data to tensorboard log file.

         Parameters
         ----------
         name : str
             Tag name to use for profile info
+        export_tb : bool
+            Flag to enable/disable tensorboard profiling
         """
-        if self._tb_writer is not None and self._write_tb_profile:
+        if self._tb_writer is not None and export_tb:
             with self._tb_writer.as_default():
                 tf.summary.trace_export(
                     name=name,
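
For reference, a minimal standalone sketch of the TensorFlow trace/export pattern this mixin wraps (the log directory, traced function, and tag name below are illustrative, not sup3r code):

    import tensorflow as tf

    logdir = './tb_logs'  # illustrative log directory
    writer = tf.summary.create_file_writer(logdir)

    @tf.function
    def step(x):
        # Toy graph to trace; stands in for a training step.
        return tf.reduce_sum(x * x)

    tf.summary.trace_on(graph=True, profiler=True)
    step(tf.random.normal((4, 4)))
    with writer.as_default():
        tf.summary.trace_export(name='training_epoch', step=0,
                                profiler_outdir=logdir)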

sup3r/preprocessing/base.py

Lines changed: 8 additions & 4 deletions
@@ -24,6 +24,7 @@
     composite_info,
     is_type_of,
 )
+from sup3r.utilities.utilities import Timer

 logger = logging.getLogger(__name__)

@@ -44,7 +45,7 @@ class Sup3rMeta(ABCMeta, type):
     kwargs as ``*args`` / ``**kwargs`` or those built through factory
     composition, for example."""

-    def __new__(mcs, name, bases, namespace, **kwargs):  # noqa: N804
+    def __new__(mcs, name, bases, namespace, **kwargs):
         """Define __name__ and __signature__"""
         sig, doc = _get_class_info(namespace)
         name = namespace.get('__name__', name)
@@ -201,11 +202,13 @@ def __getattr__(self, attr):
             out = out[0]
         return out

-    def _getattr(self, dset, attr):
+    @classmethod
+    def _getattr(cls, dset, attr):
         """Get attribute from single data member."""
         return getattr(dset.sx, attr, getattr(dset, attr))

-    def _getitem(self, dset, item):
+    @classmethod
+    def _getitem(cls, dset, item):
         """Get item from single data member."""
         return dset.sx[item] if hasattr(dset, 'sx') else dset[item]

@@ -318,7 +321,7 @@ class Container(metaclass=Sup3rMeta):
     ``Sup3rX`` objects (:class:`.Sup3rDataset`), or a tuple of such objects.
     """

-    __slots__ = ['_data']
+    __slots__ = ['_data', 'timer']

     def __init__(
         self,
@@ -353,6 +356,7 @@
             :class:`~.samplers.DualSampler`, and a 1-tuple otherwise.
         """
         self.data = data
+        self.timer = Timer()

     @property
     def data(self):
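
The ``_getattr`` helper above uses ``getattr`` with a default to prefer the wrapped ``.sx`` accessor and fall back to the raw member. A generic sketch of that dispatch pattern with made-up classes (note the fallback argument is evaluated eagerly, so the wrapped object must also expose the attribute, as it does in this sketch):

    class Accessor:
        """Stand-in for the preferred `.sx` accessor."""
        shape = (10, 10, 24)

    class Member:
        """Stand-in for a raw data member wrapped by the container."""
        sx = Accessor()
        shape = (5, 5, 24)
        features = ['windspeed']

    def _getattr(dset, attr):
        # Prefer the attribute on the accessor; fall back to the raw member.
        return getattr(dset.sx, attr, getattr(dset, attr))

    print(_getattr(Member(), 'shape'))     # (10, 10, 24), taken from the accessor
    print(_getattr(Member(), 'features'))  # ['windspeed'], falls back to the member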

sup3r/preprocessing/batch_handlers/factory.py

Lines changed: 4 additions & 0 deletions
@@ -211,6 +211,7 @@ def __init__(
             sample_shape=sample_shape,
             feature_sets=feature_sets,
             batch_size=batch_size,
+            mode=mode,
             sampler_kwargs=sampler_kwargs,
         )

@@ -260,6 +261,7 @@ def init_samplers(
         sample_shape,
         feature_sets,
         batch_size,
+        mode,
         sampler_kwargs,
     ):
         """Initialize samplers from given data containers."""
@@ -269,6 +271,7 @@
                 sample_shape=sample_shape,
                 feature_sets=feature_sets,
                 batch_size=batch_size,
+                mode=mode,
                 **sampler_kwargs,
             )
             for container in train_containers
@@ -282,6 +285,7 @@
                 sample_shape=sample_shape,
                 feature_sets=feature_sets,
                 batch_size=batch_size,
+                mode=mode,
                 **sampler_kwargs,
             )
             for container in val_containers
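
A hedged sketch of the new kwarg from the batch-handler side: the diff confirms that ``mode`` (along with ``sample_shape``, ``feature_sets``, ``batch_size``, and ``sampler_kwargs``) is forwarded to every sampler the factory builds, while the handler class name, argument order, and data containers below are assumptions for illustration:

    # Hypothetical construction; train_containers / val_containers are assumed
    # to be existing sup3r data containers, and the exact BatchHandler
    # signature is not taken from this diff.
    batch_handler = BatchHandler(
        train_containers,
        val_containers=val_containers,
        batch_size=16,
        sample_shape=(20, 20, 12),
        mode='eager',  # forwarded to every train and validation sampler
    )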

sup3r/preprocessing/batch_queues/abstract.py

Lines changed: 0 additions & 4 deletions
@@ -153,10 +153,6 @@ def preflight(self):
         )
         assert sampler_bs == self.batch_size, msg

-        if self.mode == 'eager':
-            logger.info('Received mode = "eager".')
-            _ = [c.compute() for c in self.containers]
-
     @property
     def queue_thread(self):
         """Get new queue thread."""

sup3r/preprocessing/samplers/base.py

Lines changed: 29 additions & 3 deletions
@@ -29,6 +29,7 @@ def __init__(
         sample_shape: Optional[tuple] = None,
         batch_size: int = 16,
         feature_sets: Optional[dict] = None,
+        mode: str = 'lazy',
     ):
         """
         Parameters
@@ -65,12 +66,18 @@
             in the high-resolution observation but not expected to be
             output from the generative model. An example is high-res
             topography that is to be injected mid-network.
+        mode : str
+            Mode for sampling data. Options are 'lazy' or 'eager'. 'eager' mode
+            pre-loads all data into memory as numpy arrays for faster access.
+            'lazy' mode samples directly from the underlying data object, which
+            could be backed by dask arrays or on-disk netCDF files.
         """
         super().__init__(data=data)
         feature_sets = feature_sets or {}
         self.features = feature_sets.get('features', self.data.features)
         self._lr_only_features = feature_sets.get('lr_only_features', [])
         self._hr_exo_features = feature_sets.get('hr_exo_features', [])
+        self.mode = mode
         self.sample_shape = sample_shape or (10, 10, 1)
         self.batch_size = batch_size
         self.lr_features = self.features
@@ -133,6 +140,9 @@ def preflight(self):
         if self.data.shape[2] < self.sample_shape[2] * self.batch_size:
             logger.warning(msg)
             warn(msg)
+        if self.mode == 'eager':
+            logger.info('Received mode = "eager".')
+            _ = self.compute()

     @property
     def sample_shape(self) -> tuple:
@@ -197,7 +207,8 @@ def _reshape_samples(self, samples):
         # (batch_size, lats, lons, times, feats)
         return np.transpose(out, axes=(2, 0, 1, 3, 4))

-    def _stack_samples(self, samples):
+    @classmethod
+    def _stack_samples(cls, samples):
         """Used to build batch arrays in the case of independent time samples
         (e.g. slow batching)

@@ -225,10 +236,25 @@ def _stack_samples(self, samples):
             return (lr, hr)
         return np.stack(samples, axis=0)

+    def _compute_samples(self, samples):
+        """Cast samples to numpy arrays. This only does something when samples
+        are dask arrays.
+
+        Parameters
+        ----------
+        samples : tuple[np.ndarray | da.core.Array, ...] |
+                  np.ndarray | da.core.Array
+            Samples retrieved from the underlying data. Could be a tuple
+            in the case of dual datasets.
+        """
+        if self.mode == 'eager':
+            return samples
+        return compute_if_dask(samples)
+
     def _fast_batch(self):
         """Get batch of samples with adjacent time slices."""
         out = self.data.sample(self.get_sample_index(n_obs=self.batch_size))
-        out = compute_if_dask(out)
+        out = self._compute_samples(out)
         if isinstance(out, tuple):
             return tuple(self._reshape_samples(o) for o in out)
         return self._reshape_samples(out)
@@ -239,7 +265,7 @@ def _slow_batch(self):
             self.data.sample(self.get_sample_index(n_obs=1))
             for _ in range(self.batch_size)
         ]
-        out = compute_if_dask(out)
+        out = self._compute_samples(out)
         return self._stack_samples(out)

     def _fast_batch_possible(self):
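
To make the lazy/eager distinction concrete, here is a minimal standalone sketch of what ``_compute_samples`` effectively does. ``compute_if_dask`` is a sup3r utility; the stand-in below is written only for this example:

    import dask.array as da
    import numpy as np

    def compute_if_dask_sketch(arr):
        """Stand-in for sup3r's compute_if_dask: realize dask arrays and pass
        numpy arrays through untouched."""
        return arr.compute() if isinstance(arr, da.Array) else arr

    # Lazy mode: the sampler slices a dask-backed store, so each sampled chunk
    # is actually read / computed at batch time.
    lazy_data = da.random.random((100, 100, 24, 2), chunks=(50, 50, 24, 2))
    sample = lazy_data[:20, :20, :12, :]
    batch = compute_if_dask_sketch(sample)   # triggers the compute
    assert isinstance(batch, np.ndarray)

    # Eager mode: the data was loaded up front (e.g. in preflight), so samples
    # are already numpy arrays and are returned as-is with no extra compute.
    eager_data = np.asarray(lazy_data)
    batch = compute_if_dask_sketch(eager_data[:20, :20, :12, :])  # no-op path
    assert isinstance(batch, np.ndarray)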
