Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit f7a50eb

Browse files
author
The bayesnf Authors
committed
Merge pull request #38 from google:20240226-fsaad-float-time
PiperOrigin-RevId: 612850710
2 parents 0eae3f9 + f3c6d0d commit f7a50eb

10 files changed

+1467
-146
lines changed

src/bayesnf/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
# A new PyPI release will be pushed every time `__version__` is increased.
1818
# When changing this, also update the CHANGELOG.md
19-
__version__ = '0.1.2'
19+
__version__ = '0.1.3'
2020

2121
from .spatiotemporal import BayesianNeuralFieldMAP
2222
from .spatiotemporal import BayesianNeuralFieldMLE

src/bayesnf/spatiotemporal.py

Lines changed: 81 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,9 @@ def seasonality_to_float(seasonality: str, freq: str) -> float:
6060

6161

6262
def seasonalities_to_array(
63-
seasonalities: Sequence[float | str], freq: str
64-
) -> np.ndarray:
63+
seasonalities: Sequence[float | str],
64+
freq: str
65+
) -> np.ndarray:
6566
"""Convert a list of floats or strings to durations relative to a frequency.
6667
6768
Args:
@@ -95,10 +96,15 @@ def seasonalities_to_array(
9596

9697

9798
def _convert_datetime_col(table, time_column, timetype, freq, time_min=None):
99+
"""Converts a time column in place according to the frequency."""
98100
if timetype == 'index':
99101
first_date = pd.to_datetime('2020-01-01').to_period(freq)
100102
table[time_column] = table[time_column].dt.to_period(freq)
101103
table[time_column] = (table[time_column] - first_date).apply(lambda x: x.n)
104+
elif timetype == 'float':
105+
table[time_column] = table[time_column].apply(float)
106+
else:
107+
raise ValueError(f'Unknown timetype: {timetype}')
102108
if time_min is None:
103109
time_min = table[time_column].min()
104110
table[time_column] = table[time_column] - time_min
@@ -217,7 +223,7 @@ def __init__(
217223
num_seasonal_harmonics: Sequence[int] | None = None,
218224
fourier_degrees: Sequence[float] | None = None,
219225
interactions: Sequence[tuple[int, int]] | None = None,
220-
freq: str,
226+
freq: str | None = None,
221227
timetype: str = 'index',
222228
depth: int = 2,
223229
width: int = 512,
@@ -227,63 +233,41 @@ def __init__(
227233
"""Shared initialization for subclasses of BayesianNeuralFieldEstimator.
228234
229235
Args:
230-
feature_cols:
231-
Names of columns to use as features in the training
232-
data frame. The first entry denotes the name of the time variable,
233-
the remaining entries (if any) denote names of the spatial features.
234-
235-
target_col:
236-
Name of the target column representing the spatial field.
237-
238-
seasonality_periods:
239-
A list of numbers representing the seasonal frequencies of the data
240-
in the time domain. It is also possible to specify a string such as
241-
'W', 'D', etc. corresponding to a valid Pandas frequency: see the
242-
Pandas [Offset Aliases](
243-
https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases)
244-
for valid values.
245-
246-
num_seasonal_harmonics:
247-
A list of seasonal harmonics, one for each entry in
248-
`seasonality_periods`. The number of seasonal harmonics (h) for a
249-
given seasonal period `p` must satisfy `h < p//2`.
250-
251-
fourier_degrees:
252-
A list of integer degrees for the Fourier features of the inputs.
253-
If given, must have the same length as `feature_cols`.
254-
255-
interactions:
256-
A list of tuples of column indexes for the first-order
236+
feature_cols: Names of columns to use as features in the training data
237+
frame. The first entry denotes the name of the time variable, the
238+
remaining entries (if any) denote names of the spatial features.
239+
target_col: Name of the target column representing the spatial field.
240+
seasonality_periods: A list of numbers representing the seasonal
241+
frequencies of the data in the time domain. If timetype == 'index', then
242+
it is possible to specify numeric frequencies by using string short
243+
hands such as 'W', 'D', etc., which correspond to a valid Pandas
244+
frequency. See Pandas [Offset
245+
Aliases](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases)
246+
for valid string values.
247+
num_seasonal_harmonics: A list of seasonal harmonics, one for each entry
248+
in `seasonality_periods`. The number of seasonal harmonics (h) for a
249+
given seasonal period `p` must satisfy `h < p//2`. It is an error fir
250+
`len(num_seasonal_harmonics) != len(seasonality_periods)`. Should be
251+
used only if `timetype == 'index'`.
252+
fourier_degrees: A list of integer degrees for the Fourier features of the
253+
inputs. If given, must have the same length as `feature_cols`.
254+
interactions: A list of tuples of column indexes for the first-order
257255
interactions. For example `[(0,1), (1,2)]` creates two interaction
258-
features
259-
260-
- `feature_cols[0] * feature_cols[1]`
261-
- `feature_cols[1] * feature_cols[2]`
262-
263-
freq:
264-
A frequency string for the sampling rate at which the data is
265-
collected. See the Pandas
266-
[Offset Aliases](
267-
https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases)
268-
for valid values.
269-
270-
timetype:
271-
Must be specified as `index`. The general versions will be
272-
integrated pending https://github.com/google/bayesnf/issues/16.
273-
274-
depth:
275-
The number of hidden layers in the BayesNF architecture.
276-
277-
width:
278-
The number of hidden units in each layer.
279-
280-
observation_model:
281-
The aleatoric noise model for the observed data. The options are
282-
`NORMAL` (Gaussian noise), `NB` (negative binomial noise), or `ZNB`
283-
(zero-inflated negative binomial noise).
284-
285-
standardize:
286-
List of columns that should be standardized. It is highly
256+
features - `feature_cols[0] * feature_cols[1]` - `feature_cols[1] *
257+
feature_cols[2]`
258+
freq: A frequency string for the sampling rate at which the data is
259+
collected. See the Pandas [Offset
260+
Aliases](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases)
261+
for valid values. Should be used if and only if `timetype == 'index'`.
262+
timetype: Either `index` or `float`. If `index`, then the time column must
263+
be a `datetime` type and `freq` must be given. Otherwise, if `float`,
264+
then the time column must be `float`.
265+
depth: The number of hidden layers in the BayesNF architecture.
266+
width: The number of hidden units in each layer.
267+
observation_model: The aleatoric noise model for the observed data. The
268+
options are `NORMAL` (Gaussian noise), `NB` (negative binomial noise),
269+
or `ZNB` (zero-inflated negative binomial noise).
270+
standardize: List of columns that should be standardized. It is highly
287271
recommended to standardize `feature_cols[1:]`. It is an error if
288272
`features_cols[0]` (the time variable) is in `standardize`.
289273
"""
@@ -337,18 +321,48 @@ def _get_interactions(self) -> np.ndarray:
337321
f' passed shape was {interactions.shape})')
338322
return interactions
339323

324+
def _get_seasonality_periods(self):
325+
"""Return array of seasonal periods."""
326+
if (
327+
(self.timetype == 'index' and self.freq is None) or
328+
(self.timetype == 'float' and self.freq is not None)):
329+
raise ValueError(f'Invalid {self.freq=} with {self.timetype=}.')
330+
if self.seasonality_periods is None:
331+
return np.zeros(0)
332+
if self.timetype == 'index':
333+
return seasonalities_to_array(self.seasonality_periods, self.freq)
334+
if self.timetype == 'float':
335+
return np.asarray(self.seasonality_periods, dtype=float)
336+
assert False, f'Impossible {self.timetype=}.'
337+
338+
def _get_num_seasonal_harmonics(self):
339+
"""Return array of seasonal harmonics per seasonal period."""
340+
# Discrete time.
341+
if self.timetype == 'index':
342+
return (
343+
np.array(self.num_seasonal_harmonics)
344+
if self.num_seasonal_harmonics is not None else
345+
np.zeros(0))
346+
# Continuous time.
347+
if self.timetype == 'float':
348+
if self.num_seasonal_harmonics is not None:
349+
raise ValueError(
350+
f'Cannot use num_seasonal_harmonics with {self.timetype=}.'
351+
)
352+
# HACK: models.make_seasonal_frequencies assumes the data is discrete
353+
# time where each harmonic h is between 1, ..., p/2 and the harmonic
354+
# factors are np.arange(1, h + 1). Since our goal with continuous
355+
# time data is exactly 1 harmonic per seasonal factor, any h between
356+
# 0 and min(0.5, p/2) will work, as np.arange(1, 1+h) = [1]
357+
return np.fmin(.5, self._get_seasonality_periods() / 2)
358+
assert False, f'Impossible {self.timetype=}.'
359+
340360
def _model_args(self, batch_shape):
341361
return {
342362
'depth': self.depth,
343363
'input_scales': self.data_handler.get_input_scales(),
344-
'num_seasonal_harmonics':
345-
np.array(self.num_seasonal_harmonics)
346-
if self.num_seasonal_harmonics is not None
347-
else np.zeros(0),
348-
'seasonality_periods':
349-
seasonalities_to_array(self.seasonality_periods, self.freq)
350-
if self.seasonality_periods is not None
351-
else np.zeros(0),
364+
'num_seasonal_harmonics': self._get_num_seasonal_harmonics(),
365+
'seasonality_periods': self._get_seasonality_periods(),
352366
'width': self.width,
353367
'init_x': batch_shape,
354368
'fourier_degrees': self._get_fourier_degrees(batch_shape),

tests/spatiotemporal_test.py

Lines changed: 0 additions & 45 deletions
This file was deleted.

0 commit comments

Comments
 (0)