Skip to content

Commit af16f80

Browse files
docs: add docstrings to predict function
1 parent 788e066 commit af16f80

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

src/predict.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,49 @@
1+
@doc raw"""
2+
predict(rc, steps::Integer, ps, st; initialdata=nothing)
3+
predict(rc, data::AbstractMatrix, ps, st)
4+
5+
Run the model either in (1) closed-loop (auto-regressive) mode for a fixed number
6+
of steps, or in (2) teacher-forced (point-by-point) mode over a given input
7+
sequence.
8+
9+
## 1) Auto-regressive rollout
10+
11+
**Behavior**
12+
- Rolls the model forward for `steps` time steps.
13+
- At each step, the model’s output becomes the next input.
14+
15+
### Arguments
16+
- `rc`: The reservoir chain / model.
17+
- `steps::Integer`: Number of time steps to generate.
18+
- `ps`: Model parameters (from `setup` or after `train!`).
19+
- `st`: Model state (carry, RNG replicas, etc.), threaded across time.
20+
21+
### Keyword Arguments
22+
- `initialdata=nothing`: Column vector used as the first input.
23+
24+
### Returns
25+
26+
- `Y`: Generated outputs of shape `(out_dims, steps)`.
27+
- `st_out`: Final model state after `steps` steps.
28+
29+
30+
## 2) Teacher-forced / point-by-point
31+
32+
- Feeds each column of `data` as input; the model state is threaded across time,
33+
and an output is produced for each input column.
34+
35+
### Arguments
36+
37+
- `rc`: The reservoir chain / model.
38+
- `data`: Input sequence of shape `(in_dims, T)` (columns are time).
39+
- `ps`: Model parameters.
40+
- `st`: Initial model state before processing `data`.
41+
42+
### Returns
43+
44+
- `Y::AbstractMatrix`: Outputs for each input column, shape `(out_dims, T)`.
45+
- `st_out`: Final model state after consuming all `T` columns.
46+
"""
147
function predict(rc::AbstractLuxLayer, steps::Int, ps, st; initialdata=nothing)
248
if initialdata == nothing
349
initialdata = rand(Float32, 3)

0 commit comments

Comments
 (0)