@@ -60,8 +60,9 @@ def seasonality_to_float(seasonality: str, freq: str) -> float:
6060
6161
6262def seasonalities_to_array (
63- seasonalities : Sequence [float | str ], freq : str
64- ) -> np .ndarray :
63+ seasonalities : Sequence [float | str ],
64+ freq : str
65+ ) -> np .ndarray :
6566 """Convert a list of floats or strings to durations relative to a frequency.
6667
6768 Args:
@@ -95,10 +96,15 @@ def seasonalities_to_array(
9596
9697
9798def _convert_datetime_col (table , time_column , timetype , freq , time_min = None ):
99+ """Converts a time column in place according to the frequency."""
98100 if timetype == 'index' :
99101 first_date = pd .to_datetime ('2020-01-01' ).to_period (freq )
100102 table [time_column ] = table [time_column ].dt .to_period (freq )
101103 table [time_column ] = (table [time_column ] - first_date ).apply (lambda x : x .n )
104+ elif timetype == 'float' :
105+ table [time_column ] = table [time_column ].apply (float )
106+ else :
107+ raise ValueError (f'Unknown timetype: { timetype } ' )
102108 if time_min is None :
103109 time_min = table [time_column ].min ()
104110 table [time_column ] = table [time_column ] - time_min
@@ -217,7 +223,7 @@ def __init__(
217223 num_seasonal_harmonics : Sequence [int ] | None = None ,
218224 fourier_degrees : Sequence [float ] | None = None ,
219225 interactions : Sequence [tuple [int , int ]] | None = None ,
220- freq : str ,
226+ freq : str | None = None ,
221227 timetype : str = 'index' ,
222228 depth : int = 2 ,
223229 width : int = 512 ,
@@ -227,63 +233,41 @@ def __init__(
227233 """Shared initialization for subclasses of BayesianNeuralFieldEstimator.
228234
229235 Args:
230- feature_cols:
231- Names of columns to use as features in the training
232- data frame. The first entry denotes the name of the time variable,
233- the remaining entries (if any) denote names of the spatial features.
234-
235- target_col:
236- Name of the target column representing the spatial field.
237-
238- seasonality_periods:
239- A list of numbers representing the seasonal frequencies of the data
240- in the time domain. It is also possible to specify a string such as
241- 'W', 'D', etc. corresponding to a valid Pandas frequency: see the
242- Pandas [Offset Aliases](
243- https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases)
244- for valid values.
245-
246- num_seasonal_harmonics:
247- A list of seasonal harmonics, one for each entry in
248- `seasonality_periods`. The number of seasonal harmonics (h) for a
249- given seasonal period `p` must satisfy `h < p//2`.
250-
251- fourier_degrees:
252- A list of integer degrees for the Fourier features of the inputs.
253- If given, must have the same length as `feature_cols`.
254-
255- interactions:
256- A list of tuples of column indexes for the first-order
236+ feature_cols: Names of columns to use as features in the training data
237+ frame. The first entry denotes the name of the time variable, the
238+ remaining entries (if any) denote names of the spatial features.
239+ target_col: Name of the target column representing the spatial field.
240+ seasonality_periods: A list of numbers representing the seasonal
241+ frequencies of the data in the time domain. If timetype == 'index', then
242+ it is possible to specify numeric frequencies by using string
243+ shorthands such as 'W', 'D', etc., which correspond to a valid Pandas
244+ frequency. See Pandas [Offset
245+ Aliases](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases)
246+ for valid string values.
247+ num_seasonal_harmonics: A list of seasonal harmonics, one for each entry
248+ in `seasonality_periods`. The number of seasonal harmonics (h) for a
249+ given seasonal period `p` must satisfy `h < p//2`. It is an error if
250+ `len(num_seasonal_harmonics) != len(seasonality_periods)`. Should be
251+ used only if `timetype == 'index'`.
252+ fourier_degrees: A list of integer degrees for the Fourier features of the
253+ inputs. If given, must have the same length as `feature_cols`.
254+ interactions: A list of tuples of column indexes for the first-order
257255 interactions. For example `[(0,1), (1,2)]` creates two interaction
258- features
259-
260- - `feature_cols[0] * feature_cols[1]`
261- - `feature_cols[1] * feature_cols[2]`
262-
263- freq:
264- A frequency string for the sampling rate at which the data is
265- collected. See the Pandas
266- [Offset Aliases](
267- https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases)
268- for valid values.
269-
270- timetype:
271- Must be specified as `index`. The general versions will be
272- integrated pending https://github.com/google/bayesnf/issues/16.
273-
274- depth:
275- The number of hidden layers in the BayesNF architecture.
276-
277- width:
278- The number of hidden units in each layer.
279-
280- observation_model:
281- The aleatoric noise model for the observed data. The options are
282- `NORMAL` (Gaussian noise), `NB` (negative binomial noise), or `ZNB`
283- (zero-inflated negative binomial noise).
284-
285- standardize:
286- List of columns that should be standardized. It is highly
256+ features: `feature_cols[0] * feature_cols[1]` and `feature_cols[1] *
257+ feature_cols[2]`.
258+ freq: A frequency string for the sampling rate at which the data is
259+ collected. See the Pandas [Offset
260+ Aliases](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases)
261+ for valid values. Should be used if and only if `timetype == 'index'`.
262+ timetype: Either `index` or `float`. If `index`, then the time column must
263+ be a `datetime` type and `freq` must be given. Otherwise, if `float`,
264+ then the time column must be `float`.
265+ depth: The number of hidden layers in the BayesNF architecture.
266+ width: The number of hidden units in each layer.
267+ observation_model: The aleatoric noise model for the observed data. The
268+ options are `NORMAL` (Gaussian noise), `NB` (negative binomial noise),
269+ or `ZNB` (zero-inflated negative binomial noise).
270+ standardize: List of columns that should be standardized. It is highly
287271 recommended to standardize `feature_cols[1:]`. It is an error if
288272 `feature_cols[0]` (the time variable) is in `standardize`.
289273 """
@@ -337,18 +321,48 @@ def _get_interactions(self) -> np.ndarray:
337321 f' passed shape was { interactions .shape } )' )
338322 return interactions
339323
324+ def _get_seasonality_periods (self ):
325+ """Return array of seasonal periods."""
326+ if (
327+ (self .timetype == 'index' and self .freq is None ) or
328+ (self .timetype == 'float' and self .freq is not None )):
329+ raise ValueError (f'Invalid { self .freq = } with { self .timetype = } .' )
330+ if self .seasonality_periods is None :
331+ return np .zeros (0 )
332+ if self .timetype == 'index' :
333+ return seasonalities_to_array (self .seasonality_periods , self .freq )
334+ if self .timetype == 'float' :
335+ return np .asarray (self .seasonality_periods , dtype = float )
336+ assert False , f'Impossible { self .timetype = } .'
337+
338+ def _get_num_seasonal_harmonics (self ):
339+ """Return array of seasonal harmonics per seasonal period."""
340+ # Discrete time.
341+ if self .timetype == 'index' :
342+ return (
343+ np .array (self .num_seasonal_harmonics )
344+ if self .num_seasonal_harmonics is not None else
345+ np .zeros (0 ))
346+ # Continuous time.
347+ if self .timetype == 'float' :
348+ if self .num_seasonal_harmonics is not None :
349+ raise ValueError (
350+ f'Cannot use num_seasonal_harmonics with { self .timetype = } .'
351+ )
352+ # HACK: models.make_seasonal_frequencies assumes the data is discrete
353+ # time where each harmonic h is between 1, ..., p/2 and the harmonic
354+ # factors are np.arange(1, h + 1). Since our goal with continuous
355+ # time data is exactly 1 harmonic per seasonal factor, any h between
356+ # 0 and min(0.5, p/2) will work, as np.arange(1, 1+h) = [1]
357+ return np .fmin (.5 , self ._get_seasonality_periods () / 2 )
358+ assert False , f'Impossible { self .timetype = } .'
359+
340360 def _model_args (self , batch_shape ):
341361 return {
342362 'depth' : self .depth ,
343363 'input_scales' : self .data_handler .get_input_scales (),
344- 'num_seasonal_harmonics' :
345- np .array (self .num_seasonal_harmonics )
346- if self .num_seasonal_harmonics is not None
347- else np .zeros (0 ),
348- 'seasonality_periods' :
349- seasonalities_to_array (self .seasonality_periods , self .freq )
350- if self .seasonality_periods is not None
351- else np .zeros (0 ),
364+ 'num_seasonal_harmonics' : self ._get_num_seasonal_harmonics (),
365+ 'seasonality_periods' : self ._get_seasonality_periods (),
352366 'width' : self .width ,
353367 'init_x' : batch_shape ,
354368 'fourier_degrees' : self ._get_fourier_degrees (batch_shape ),
0 commit comments