Patch summary: adds TemporalGRU (an nn.GRU with a hidden-state buffer for
streaming inference) and the InternalBuffer marker class that the buffer
helpers in tcn.py now dispatch on.
Review note: line breaks and indentation were restored from a whitespace-collapsed
copy of this patch; the 'index' blob ids are those of the original patch.

diff --git a/pytorch_tcn/buffer.py b/pytorch_tcn/buffer.py
index c74d0d2..86086a6 100644
--- a/pytorch_tcn/buffer.py
+++ b/pytorch_tcn/buffer.py
@@ -6,6 +6,9 @@
 from typing import List
 from collections.abc import Iterable
 
+class InternalBuffer():
+    # Parent class that requires that there is a buffer attribute
+    pass
 
 class BufferIO():
     def __init__(
diff --git a/pytorch_tcn/pad.py b/pytorch_tcn/pad.py
index e673768..b28c65a 100644
--- a/pytorch_tcn/pad.py
+++ b/pytorch_tcn/pad.py
@@ -4,7 +4,7 @@
 import torch.nn as nn
 import math
 
-from .buffer import BufferIO
+from .buffer import BufferIO, InternalBuffer
 
 from typing import Optional
 from typing import Union
@@ -18,7 +18,7 @@
     'circular',
 ]
 
-class TemporalPad1d(nn.Module):
+class TemporalPad1d(nn.Module, InternalBuffer):
     def __init__(
             self,
             padding: int,
diff --git a/pytorch_tcn/rnn.py b/pytorch_tcn/rnn.py
new file mode 100644
index 0000000..f7bbc5e
--- /dev/null
+++ b/pytorch_tcn/rnn.py
@@ -0,0 +1,168 @@
+import os
+import warnings
+from typing import Optional
+
+from .buffer import InternalBuffer
+import torch
+import torch.nn as nn
+
+class TemporalGRU(nn.GRU, InternalBuffer):
+    """
+    A custom GRU layer that supports streaming inference with internal buffering,
+    following a similar paradigm to TemporalConv1d.
+
+    The TemporalGRU behaves exactly like a standard nn.GRU during training.
+    When inference is enabled (inference=True), it uses an internal hidden-state
+    buffer (or one provided via a BufferIO object) to maintain continuity across
+    sequential calls (e.g. for streaming applications). The updated hidden state
+    is stored internally (or via the provided buffer_io) so that subsequent calls
+    will use the previous state.
+
+    Args:
+        input_size (int): The number of expected features in the input.
+        hidden_size (int): The number of features in the hidden state.
+        num_layers (int, optional): Number of recurrent layers. Default: 1.
+        bias (bool, optional): If False, then the layer does not use bias weights.
+                               Default: True.
+        batch_first (bool, optional): If True, then the input and output tensors are
+                                      provided as (batch, seq, feature). Default: False.
+        dropout (float, optional): If non-zero, introduces a Dropout layer on the outputs
+                                   of each GRU layer except the last layer.
+        bidirectional (bool, optional): If True, becomes a bidirectional GRU.
+        buffer (torch.Tensor, float or int, optional): Initial hidden state buffer.
+            A tensor is used as-is; a number is expanded to a constant tensor of
+            shape (num_layers * num_directions, 1, hidden_size). Typically, this
+            is left as None so that the GRU will use a zero initial state.
+        **kwargs: Additional keyword arguments passed to nn.GRU.
+
+    Raises:
+        ValueError: If 'buffer' is not None, a number, or a torch.Tensor.
+    """
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        num_layers: int = 1,
+        bias: bool = True,
+        batch_first: bool = False,
+        dropout: float = 0.0,
+        bidirectional: bool = False,
+        buffer: Optional[torch.Tensor] = None,
+        **kwargs,
+        ):
+        super(TemporalGRU, self).__init__(
+            input_size=input_size,
+            hidden_size=hidden_size,
+            num_layers=num_layers,
+            bias=bias,
+            batch_first=batch_first,
+            dropout=dropout,
+            bidirectional=bidirectional,
+            **kwargs,
+        )
+        # Hidden state buffer for streaming inference.
+        # Its shape should be (num_layers * num_directions, batch, hidden_size).
+        # None is kept as None (matching reset_buffer) so that nn.GRU falls back
+        # to a zero initial state for whatever batch size the input has.
+        if buffer is None:
+            self._buffer = None
+        elif isinstance(buffer, (int, float)):
+            # float() keeps the tensor in the default floating dtype; an int
+            # fill_value would otherwise produce an integer tensor.
+            self._buffer = torch.full(
+                size = (num_layers * (2 if bidirectional else 1), 1, hidden_size),
+                fill_value = float(buffer),
+                )
+        elif isinstance(buffer, torch.Tensor):
+            self._buffer = buffer
+        else:
+            raise ValueError(
+                f"""
+                The argument 'buffer' must be None or of type float,
+                int, or torch.Tensor, but got {type(buffer)}.
+                """
+                )
+
+    @property
+    def buffer(self) -> Optional[torch.Tensor]:
+        """
+        Returns the hidden state buffer (None if no state is stored).
+        """
+        return self._buffer
+
+    @buffer.setter
+    def buffer(self, value: Optional[torch.Tensor]) -> None:
+        """
+        Sets the hidden state buffer.
+        """
+        self._buffer = value
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        inference: bool = False,
+        in_buffer: Optional[torch.Tensor] = None,
+        buffer_io: Optional["BufferIO"] = None,
+        ) -> torch.Tensor:
+        """
+        Forward pass of the TemporalGRU.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+            inference (bool, optional): If True, streaming inference is enabled
+                                        and an internal hidden state buffer is used.
+                                        Default: False.
+            in_buffer (torch.Tensor): This argument has been deprecated.
+                                      Use buffer_io instead.
+            buffer_io (Optional[BufferIO]): An object that holds a 'buffer' attribute.
+                                            This can be used to externally manage the hidden state.
+
+        Returns:
+            torch.Tensor: The output sequence of the GRU.
+
+        Raises:
+            ValueError: If the removed argument 'in_buffer' is passed.
+        """
+        if in_buffer is not None:
+            raise ValueError(
+                """
+                The argument 'in_buffer' was removed.
+                Instead, you should pass the hidden state buffer as a BufferIO object
+                to the argument 'buffer_io'.
+                """
+                )
+
+        if not inference:
+            # In training mode, perform a standard forward pass.
+            output, _ = super(TemporalGRU, self).forward(x)
+            return output
+        else:
+            # In inference (streaming) mode, use the stored hidden state buffer.
+            if buffer_io is None:
+                h0 = self._buffer
+            else:
+                h0 = buffer_io.next_in_buffer()
+                if h0 is None:
+                    h0 = self._buffer
+                    buffer_io.append_internal_buffer(h0)
+
+            # If no buffer is present, default to None so that nn.GRU uses zero initial state.
+            if h0 is None:
+                output, hidden = super(TemporalGRU, self).forward(x)
+            else:
+                output, hidden = super(TemporalGRU, self).forward(x, h0)
+
+            # Update the buffer with the new hidden state.
+            if buffer_io is None:
+                self._buffer = hidden
+            else:
+                buffer_io.append_out_buffer( hidden )
+
+
+            return output
+
+    def reset_buffer(self) -> None:
+        """
+        Resets the internal hidden state buffer to None.
+        """
+        self._buffer = None
\ No newline at end of file
diff --git a/pytorch_tcn/tcn.py b/pytorch_tcn/tcn.py
index bc75e3b..373e41d 100644
--- a/pytorch_tcn/tcn.py
+++ b/pytorch_tcn/tcn.py
@@ -22,9 +22,9 @@
 from typing import Union
 from typing import Optional
 from collections.abc import Iterable
-from pytorch_tcn.conv import TemporalConv1d, TemporalConvTranspose1d
+from pytorch_tcn.conv import TemporalConv1d
 from pytorch_tcn.pad import TemporalPad1d
-from pytorch_tcn.buffer import BufferIO
+from pytorch_tcn.buffer import BufferIO, InternalBuffer
 
 
 activation_fn = dict(
@@ -173,7 +173,7 @@ def _init_weights(m):
 
     def reset_buffers(self):
         def _reset_buffer(x):
-            if isinstance(x, (TemporalPad1d,) ):
+            if isinstance(x, (InternalBuffer,) ):
                 x.reset_buffer()
         self.apply(_reset_buffer)
         return
@@ -184,7 +184,7 @@ def get_buffers(self):
         """
         buffers = []
         def _get_buffers(x):
-            if isinstance(x, (TemporalPad1d,) ):
+            if isinstance(x, (InternalBuffer,) ):
                 buffers.append(x.buffer)
         self.apply(_get_buffers)
         return buffers
@@ -215,7 +215,7 @@ def set_buffers(self, buffers):
         Set all buffers of the network in the order they were created.
         """
         def _set_buffers(x):
-            if isinstance(x, (TemporalPad1d,) ):
+            if isinstance(x, (InternalBuffer,) ):
                 x.buffer = buffers.pop(0)
         self.apply(_set_buffers)
         return
diff --git a/tests/unit/test_rnn.py b/tests/unit/test_rnn.py
new file mode 100644
index 0000000..354b4c6
--- /dev/null
+++ b/tests/unit/test_rnn.py
@@ -0,0 +1,215 @@
+import os
+import tempfile
+import torch
+import torch.nn as nn
+import unittest
+
+from pytorch_tcn.buffer import BufferIO
+# Replace the import below with the appropriate path to your TemporalGRU implementation.
+from pytorch_tcn.rnn import TemporalGRU
+
+
+class RNNModel(nn.Module):
+    """
+    A simple model composed of two TemporalGRU layers. It expects an external buffer list,
+    which is wrapped in a BufferIO object. This is analogous to the ConvModel for the temporal convs.
+    """
+    def __init__(self, layer_1, layer_2):
+        super(RNNModel, self).__init__()
+        self.rnn1 = layer_1
+        self.rnn2 = layer_2
+
+    def forward(self, x, in_buffers):
+        buffer_io = BufferIO(in_buffers=in_buffers)
+        with torch.no_grad():
+            x = self.rnn1(x, inference=True, buffer_io=buffer_io)
+            x = self.rnn2(x, inference=True, buffer_io=buffer_io)
+        out_buffers = buffer_io.out_buffers
+        return x, out_buffers
+
+
+class TemporalGRUTest(unittest.TestCase):
+
+    def test_rnn_streaming_with_internal_buffer(self):
+        # Use batch_first=True so that x has shape (batch, seq, features)
+        batch_size = 1
+        seq_len = 32
+        input_size = 3
+        hidden_size = 16
+        num_layers = 1
+
+        input_tensor = torch.randn(batch_size, seq_len, input_size)
+
+        # Define the TemporalGRU layer.
+        rnn_layer = TemporalGRU(
+            input_size=input_size,
+            hidden_size=hidden_size,
+            num_layers=num_layers,
+            batch_first=True,
+        )
+        # For testing internal buffering, we initialize the buffer explicitly.
+        initial_buffer = torch.zeros(num_layers, batch_size, hidden_size)
+        rnn_layer.buffer = initial_buffer.clone()
+
+        # Create a standard RNN (with identical weights) to compute a full-sequence reference.
+        standard_rnn = nn.GRU(
+            input_size=input_size,
+            hidden_size=hidden_size,
+            num_layers=num_layers,
+            batch_first=True,
+        )
+        standard_rnn.load_state_dict(rnn_layer.state_dict())
+        full_output, _ = standard_rnn(input_tensor)
+
+        # Now run streaming inference one time-step at a time.
+        rnn_layer.reset_buffer()
+        rnn_layer.buffer = initial_buffer.clone()
+        streamed_outputs = []
+        for t in range(seq_len):
+            input_slice = input_tensor[:, t:t+1, :]  # shape: (batch, 1, input_size)
+            output_slice = rnn_layer(input_slice, inference=True)
+            streamed_outputs.append(output_slice)
+        # Concatenate the per-time-step outputs along the time dimension.
+        streamed_output = torch.cat(streamed_outputs, dim=1)
+
+        self.assertEqual(streamed_output.shape, full_output.shape)
+        self.assertTrue(torch.allclose(streamed_output, full_output, atol=1e-5))
+
+        # Verify that reset_buffer() clears the internal hidden state.
+        rnn_layer.reset_buffer()
+        self.assertIsNone(rnn_layer.buffer)
+
+    def test_rnn_streaming_with_external_buffer(self):
+        batch_size = 1
+        seq_len = 32
+        input_size = 3
+        hidden_size = 16
+        num_layers = 1
+
+        input_tensor = torch.randn(batch_size, seq_len, input_size)
+
+        rnn_layer = TemporalGRU(
+            input_size=input_size,
+            hidden_size=hidden_size,
+            num_layers=num_layers,
+            batch_first=True,
+        )
+
+        # Build a reference RNN with the same parameters/weights.
+        standard_rnn = nn.GRU(
+            input_size=input_size,
+            hidden_size=hidden_size,
+            num_layers=num_layers,
+            batch_first=True,
+        )
+        standard_rnn.load_state_dict(rnn_layer.state_dict())
+        full_output, _ = standard_rnn(input_tensor)
+
+        # Create an external buffer (hidden state) with the correct shape.
+        initial_buffer = torch.zeros(num_layers, batch_size, hidden_size)
+        buffer_io = BufferIO(in_buffers=[initial_buffer.clone()])
+
+        # Run streaming inference one time-step at a time using external buffers.
+        rnn_layer.reset_buffer()  # Ensure internal buffer is not used.
+        streamed_outputs = []
+        for t in range(seq_len):
+            input_slice = input_tensor[:, t:t+1, :]
+            output_slice = rnn_layer(input_slice, inference=True, buffer_io=buffer_io)
+            streamed_outputs.append(output_slice)
+            buffer_io.step()  # Advance the buffer pointer.
+        streamed_output = torch.cat(streamed_outputs, dim=1)
+
+        self.assertEqual(streamed_output.shape, full_output.shape)
+        self.assertTrue(torch.allclose(streamed_output, full_output, atol=1e-5))
+
+    def test_rnn_deprecated_in_buffer(self):
+        # Ensure that using the deprecated "in_buffer" argument raises an error.
+        input_tensor = torch.randn(1, 1, 3)
+        rnn_layer = TemporalGRU(input_size=3, hidden_size=4, batch_first=True)
+        with self.assertRaises(ValueError):
+            rnn_layer(input_tensor, inference=True, in_buffer=torch.zeros(1, 1, 4))
+
+    def test_rnn_streaming_with_onnx(self):
+        try:
+            import onnxruntime as ort
+        except ImportError:
+            self.skipTest("onnxruntime not available")
+
+        batch_size = 1
+        seq_len = 32
+        input_size = 3
+        hidden_size = 16
+        num_layers = 1
+
+        # Build a two-layer model using TemporalGRU.
+        rnn_layer_1 = TemporalGRU(
+            input_size=input_size,
+            hidden_size=hidden_size,
+            num_layers=num_layers,
+            batch_first=True,
+        )
+        rnn_layer_2 = TemporalGRU(
+            input_size=hidden_size,
+            hidden_size=hidden_size,
+            num_layers=num_layers,
+            batch_first=True,
+        )
+        model = RNNModel(
+            layer_1=rnn_layer_1,
+            layer_2=rnn_layer_2,
+        )
+        # Create external buffers for each layer.
+        initial_buffer_1 = torch.zeros(num_layers, batch_size, hidden_size)
+        initial_buffer_2 = torch.zeros(num_layers, batch_size, hidden_size)
+        in_buffers = [initial_buffer_1.clone(), initial_buffer_2.clone()]
+
+        # Run full inference on the entire sequence as reference.
+        input_tensor = torch.randn(batch_size, seq_len, input_size)
+        full_output, _ = model(input_tensor, in_buffers)
+
+        # Export the model to ONNX using a single time-step.
+        input_slice = input_tensor[:, :1, :]
+        with tempfile.TemporaryDirectory() as temp_dir:
+            onnx_model_name = os.path.join(temp_dir, "test_rnn_model.onnx")
+            torch.onnx.export(
+                model=model,
+                args=(input_slice, in_buffers),
+                f=onnx_model_name,
+                input_names=['in_x', 'in_buffer_1', 'in_buffer_2'],
+                output_names=['out_x', 'out_buffer_1', 'out_buffer_2'],
+                opset_version=9,
+                export_params=True,
+            )
+
+            ort_session = ort.InferenceSession(onnx_model_name)
+            onnx_stream = []
+            # Reset the buffers for streaming inference.
+            for t in range(seq_len):
+                input_slice = input_tensor[:, t:t+1, :]
+                # Run the model (reference streaming inference).
+                ref_output, out_buffers = model(input_slice, in_buffers)
+
+                # Run the ONNX model.
+                ort_inputs = {
+                    'in_x': input_slice.numpy(),
+                    'in_buffer_1': in_buffers[0].numpy(),
+                    'in_buffer_2': in_buffers[1].numpy(),
+                }
+                onnx_outputs = ort_session.run(None, ort_inputs)
+                onnx_output_slice = torch.tensor(onnx_outputs[0])
+                onnx_out_buffers = [torch.tensor(b) for b in onnx_outputs[1:]]
+
+                onnx_stream.append(onnx_output_slice)
+                # Compare the model output and buffers with the ONNX outputs.
+                self.assertTrue(torch.allclose(ref_output, onnx_output_slice, atol=1e-5))
+                for ref_buf, onnx_buf in zip(out_buffers, onnx_out_buffers):
+                    self.assertTrue(torch.allclose(ref_buf, onnx_buf, atol=1e-5))
+                # Update the buffers for the next step.
+                in_buffers = onnx_out_buffers
+
+        streamed_output = torch.cat(onnx_stream, dim=1)
+        self.assertTrue(torch.allclose(full_output, streamed_output, atol=1e-5))
+
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file