3 changes: 3 additions & 0 deletions pytorch_tcn/buffer.py
@@ -6,6 +6,9 @@
from typing import List
from collections.abc import Iterable

class InternalBuffer():
    # Marker base class: subclasses are expected to expose a 'buffer'
    # attribute and a reset_buffer() method.
pass

class BufferIO():
def __init__(
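For orientation, a minimal sketch (hypothetical, not part of this diff) of how a module opts into the buffer machinery by subclassing the InternalBuffer marker and exposing a buffer attribute:

# Hypothetical example: any nn.Module that subclasses InternalBuffer and
# exposes a `buffer` attribute (plus reset_buffer()) is picked up by the
# TCN buffer helpers shown further down in this diff.
import torch
import torch.nn as nn
from pytorch_tcn.buffer import InternalBuffer

class MyStatefulLayer(nn.Module, InternalBuffer):
    def __init__(self, size: int):
        super().__init__()
        self._buffer = torch.zeros(1, size)

    @property
    def buffer(self) -> torch.Tensor:
        return self._buffer

    @buffer.setter
    def buffer(self, value: torch.Tensor) -> None:
        self._buffer = value

    def reset_buffer(self) -> None:
        self._buffer = torch.zeros_like(self._buffer)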
4 changes: 2 additions & 2 deletions pytorch_tcn/pad.py
@@ -4,7 +4,7 @@
import torch.nn as nn
import math

from .buffer import BufferIO
from .buffer import BufferIO, InternalBuffer

from typing import Optional
from typing import Union
@@ -18,7 +18,7 @@
'circular',
]

class TemporalPad1d(nn.Module):
class TemporalPad1d(nn.Module, InternalBuffer):
def __init__(
self,
padding: int,
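A quick illustration of what this mixin change buys: InternalBuffer carries no state or methods of its own, so the MRO and nn.Module behaviour of TemporalPad1d are untouched; the only observable change is that the class is now matched by the InternalBuffer checks in tcn.py (sketch, assuming the package layout in this diff):

# The bare marker class adds no behaviour; only type checks change.
from pytorch_tcn.pad import TemporalPad1d
from pytorch_tcn.buffer import InternalBuffer

assert issubclass(TemporalPad1d, InternalBuffer)  # True after this change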
162 changes: 162 additions & 0 deletions pytorch_tcn/rnn.py
@@ -0,0 +1,162 @@
from typing import Optional

import torch
import torch.nn as nn

from .buffer import BufferIO, InternalBuffer

class TemporalGRU(nn.GRU, InternalBuffer):
"""
    A GRU layer that supports streaming inference with internal buffering,
    following the same paradigm as TemporalConv1d.

    The TemporalGRU behaves exactly like a standard nn.GRU during training.
When inference is enabled (inference=True), it uses an internal hidden-state
buffer (or one provided via a BufferIO object) to maintain continuity across
sequential calls (e.g. for streaming applications). The updated hidden state
is stored internally (or via the provided buffer_io) so that subsequent calls
will use the previous state.

Args:
input_size (int): The number of expected features in the input.
hidden_size (int): The number of features in the hidden state.
num_layers (int, optional): Number of recurrent layers. Default: 1.
bias (bool, optional): If False, then the layer does not use bias weights.
Default: True.
batch_first (bool, optional): If True, then the input and output tensors are
provided as (batch, seq, feature). Default: False.
        dropout (float, optional): If non-zero, introduces a Dropout layer on the outputs
                                   of each GRU layer except the last layer. Default: 0.0.
        bidirectional (bool, optional): If True, becomes a bidirectional GRU. Default: False.
        buffer (torch.Tensor, float, or int, optional): Initial hidden state buffer, or a
                                   scalar fill value for it. Typically left as None so the
                                   GRU starts from a zero initial state.
        **kwargs: Additional keyword arguments passed to nn.GRU.
"""
def __init__(
self,
input_size: int,
hidden_size: int,
num_layers: int = 1,
bias: bool = True,
batch_first: bool = False,
dropout: float = 0.0,
bidirectional: bool = False,
buffer: Optional[torch.Tensor] = None,
**kwargs,
):
super(TemporalGRU, self).__init__(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
bias=bias,
batch_first=batch_first,
dropout=dropout,
bidirectional=bidirectional,
**kwargs,
)
        # Hidden state buffer for streaming inference.
        # Its shape is (num_layers * num_directions, batch, hidden_size).
        if buffer is None:
            self._buffer = torch.zeros(
                num_layers * (2 if bidirectional else 1),
                1,
                hidden_size,
            )
        elif isinstance(buffer, (int, float)):
            # Cast to float so the buffer is a floating-point tensor,
            # as required for a GRU hidden state.
            self._buffer = torch.full(
                size=(num_layers * (2 if bidirectional else 1), 1, hidden_size),
                fill_value=float(buffer),
            )
        elif isinstance(buffer, torch.Tensor):
            self._buffer = buffer
        else:
            raise ValueError(
                f"""
                The argument 'buffer' must be None or of type float,
                int, or torch.Tensor, but got {type(buffer)}.
                """
            )

@property
def buffer(self) -> torch.Tensor:
"""
Returns the hidden state buffer.
"""
return self._buffer

@buffer.setter
def buffer(self, value: torch.Tensor) -> None:
"""
Sets the hidden state buffer.
"""
self._buffer = value

def forward(
self,
x: torch.Tensor,
inference: bool = False,
        in_buffer: Optional[torch.Tensor] = None,
        buffer_io: Optional[BufferIO] = None,
) -> torch.Tensor:
"""
Forward pass of the TemporalGRU.

Args:
x (torch.Tensor): Input tensor.
inference (bool, optional): If True, streaming inference is enabled
and an internal hidden state buffer is used.
Default: False.
            in_buffer (torch.Tensor, optional): Removed; passing a value raises a
                                                ValueError. Use buffer_io instead.
buffer_io (Optional[BufferIO]): An object that holds a 'buffer' attribute.
This can be used to externally manage the hidden state.

Returns:
torch.Tensor: The output sequence of the RNN.
"""
if in_buffer is not None:
raise ValueError(
"""
The argument 'in_buffer' was removed.
Instead, you should pass the hidden state buffer as a BufferIO object
to the argument 'buffer_io'.
"""
)

if not inference:
# In training mode, perform a standard forward pass.
output, _ = super(TemporalGRU, self).forward(x)
return output
else:
# In inference (streaming) mode, use the stored hidden state buffer.
if buffer_io is None:
h0 = self._buffer
else:
h0 = buffer_io.next_in_buffer()
if h0 is None:
h0 = self._buffer
buffer_io.append_internal_buffer(h0)

# If no buffer is present, default to None so that nn.RNN uses zero initial state.
if h0 is None:
output, hidden = super(TemporalGRU, self).forward(x)
else:
output, hidden = super(TemporalGRU, self).forward(x, h0)

            # Update the buffer with the new hidden state.
            if buffer_io is None:
                self._buffer = hidden
            else:
                buffer_io.append_out_buffer(hidden)

            return output

def reset_buffer(self) -> None:
"""
Resets the internal hidden state buffer to None.
"""
self._buffer = None
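A hedged usage sketch of the streaming API added in this file (dimensions and chunk sizes are illustrative, not from the PR):

# Illustrative only: sizes are made up; outputs should match the offline
# pass up to float tolerance, because the hidden state is carried between
# chunks exactly as nn.GRU would carry it within one long sequence.
import torch
from pytorch_tcn.rnn import TemporalGRU

gru = TemporalGRU(input_size=8, hidden_size=16, num_layers=2, batch_first=True)

# Offline: one standard forward pass, no state carried over.
x = torch.randn(1, 100, 8)                # (batch, seq, feature)
y = gru(x)                                # inference defaults to False

# Streaming: feed the same signal in chunks with inference=True.
gru.reset_buffer()
chunks = [gru(c, inference=True) for c in x.split(25, dim=1)]
y_stream = torch.cat(chunks, dim=1)
assert torch.allclose(y, y_stream, atol=1e-5)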
10 changes: 5 additions & 5 deletions pytorch_tcn/tcn.py
@@ -22,9 +22,9 @@
from typing import Union
from typing import Optional
from collections.abc import Iterable
from pytorch_tcn.conv import TemporalConv1d, TemporalConvTranspose1d
from pytorch_tcn.conv import TemporalConv1d
from pytorch_tcn.pad import TemporalPad1d
from pytorch_tcn.buffer import BufferIO
from pytorch_tcn.buffer import BufferIO, InternalBuffer


activation_fn = dict(
@@ -173,7 +173,7 @@ def _init_weights(m):

def reset_buffers(self):
def _reset_buffer(x):
if isinstance(x, (TemporalPad1d,) ):
if isinstance(x, (InternalBuffer,) ):
x.reset_buffer()
self.apply(_reset_buffer)
return
@@ -184,7 +184,7 @@ def get_buffers(self):
"""
buffers = []
def _get_buffers(x):
if isinstance(x, (TemporalPad1d,) ):
if isinstance(x, (InternalBuffer,) ):
buffers.append(x.buffer)
self.apply(_get_buffers)
return buffers
@@ -215,7 +215,7 @@ def set_buffers(self, buffers):
Set all buffers of the network in the order they were created.
"""
def _set_buffers(x):
if isinstance(x, (TemporalPad1d,) ):
if isinstance(x, (InternalBuffer,) ):
x.buffer = buffers.pop(0)
self.apply(_set_buffers)
return
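With the isinstance checks widened from TemporalPad1d to the InternalBuffer marker, these helpers now reach every buffered submodule, padding layers and the new recurrent layers alike. A sketch of the round-trip (TCN constructor arguments are assumed, and note that set_buffers consumes its list via pop(0)):

# Sketch only: constructor arguments are assumed for illustration.
from pytorch_tcn import TCN

model = TCN(num_inputs=8, num_channels=[16, 16])

model.reset_buffers()         # reset every InternalBuffer submodule
saved = model.get_buffers()   # snapshot buffers in creation order
# ... run some streaming inference ...
model.set_buffers(saved)      # restore; the list is consumed in order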