Allow streaming inference on more than one parallel stream.

* conv.py: both reset_buffer() methods take batch_size (default 1) and
  forward it to self.padder.reset_buffer().
* pad.py: pad_inference() no longer rejects inputs whose batch size is not 1.
  A history buffer with batch size 1 is replicated across the input batch;
  any other batch-size mismatch raises ValueError. reset_buffer() re-allocates
  the buffer when the requested batch size differs from the current one and
  zeroes it in place otherwise.
* tcn.py: reset_buffers() takes batch_size (default 1) and passes it to every
  TemporalPad1d.reset_buffer().

Notes on this copy of the patch:

* Line breaks and leading whitespace were reconstructed from a copy whose
  whitespace had been collapsed. Blank context lines were placed using the
  hunk line counts; their exact position and any trailing whitespace on
  context lines should be confirmed against the target files.
* The history slice in pad_inference() is indexed from the front of the
  tensor, because x[..., -0:] would return all of x when pad_len is 0.
* The second reset_buffer() in conv.py gained a one-line docstring, and the
  typographic hyphens/apostrophe in the added docstrings are plain ASCII.
* Because of the edits above, the post-image blob ids on the "index" lines
  are those of the original patch and no longer match the patched files.
* NOTE(review): the old pad.py reset_buffer() raised ValueError when the
  buffer's last dimension differed from pad_len; the new version has no such
  check. Confirm that dropping it is intended.

diff --git a/pytorch_tcn/conv.py b/pytorch_tcn/conv.py
index 37840a5..24ad166 100644
--- a/pytorch_tcn/conv.py
+++ b/pytorch_tcn/conv.py
@@ -131,11 +131,15 @@ def inference(self, *args, **kwargs):
             )
         return
 
-    def reset_buffer(self):
-        self.padder.reset_buffer()
+    def reset_buffer(self, batch_size: int = 1):
+        """
+        Reset the internal history buffer.
+        If you plan to process N parallel real-time streams in the next
+        call(s) pass batch_size=N so the buffer is resized accordingly.
+        """
+        self.padder.reset_buffer(batch_size=batch_size)
         return
 
-
 class TemporalConvTranspose1d(nn.ConvTranspose1d):
     def __init__(
         self,
@@ -329,6 +333,8 @@ def inference(self, *args, **kwargs):
             )
         return
 
-    def reset_buffer(self):
-        self.padder.reset_buffer()
+    def reset_buffer(self, batch_size: int = 1):
+        """Reset the internal history buffer for batch_size parallel streams."""
+        self.padder.reset_buffer(batch_size=batch_size)
+        return
 
diff --git a/pytorch_tcn/pad.py b/pytorch_tcn/pad.py
index e673768..d3a8793 100644
--- a/pytorch_tcn/pad.py
+++ b/pytorch_tcn/pad.py
@@ -129,28 +129,39 @@ def pad_inference(
                 """
                 )
 
-        if x.shape[0] != 1:
+        batch_size = x.size(0)
+
+        def _align_buffer(buf: torch.Tensor, name: str) -> torch.Tensor:
+            # buf is None -> handled by caller
+            if buf is None:
+                return None
+            if buf.size(0) == batch_size:
+                return buf
+            if buf.size(0) == 1:
+                # replicate the single history for every element in the batch
+                return buf.repeat(batch_size, 1, 1)
             raise ValueError(
-                f"""
-                Streaming inference requires a batch size
-                of 1, but batch size is {x.shape[0]}.
-                """
-                )
+                f"Batch mismatch between input (N={batch_size}) and "
+                f"{name} (N={buf.size(0)}). Either supply a buffer of the "
+                f"same batch size or call .reset_buffer(batch_size=N)."
+            )
+
 
         if buffer_io is None:
-            in_buffer = self.buffer
+            in_buffer = _align_buffer(self.buffer, 'internal buffer')
         else:
             in_buffer = buffer_io.next_in_buffer()
-            if in_buffer is None:
+            if in_buffer is None:  # first iteration, fall back
                 in_buffer = self.buffer
-                buffer_io.append_internal_buffer( in_buffer )
-
-        x = torch.cat(
-            (in_buffer, x),
-            -1,
-            )
-
-        out_buffer = x[ ..., -self.pad_len: ]
+            in_buffer = _align_buffer(in_buffer, 'in_buffer from BufferIO')
+            buffer_io.append_internal_buffer(in_buffer)
+
+        # pad the current input with the previous history
+        x = torch.cat((in_buffer, x), dim=-1)
+
+        # remember the most recent history for the *next* call; index from
+        # the front so that pad_len == 0 keeps an empty buffer, not all of x
+        out_buffer = x[..., x.size(-1) - self.pad_len:]
         if buffer_io is None:
             self.buffer = out_buffer
         else:
@@ -170,12 +181,25 @@ def forward(
             x = self.pad(x)
         return x
 
-    def reset_buffer(self):
-        self.buffer.zero_()
-        if self.buffer.shape[-1] != self.pad_len:
-            raise ValueError(
-                f"""
-                Buffer shape {self.buffer.shape} does not match the expected
-                shape (1, {self.in_channels}, {self.pad_len}).
-                """
-                )
\ No newline at end of file
+    def reset_buffer(self, batch_size: int = 1) -> None:
+        """
+        Reset the streaming buffer to zeros.
+
+        Parameters
+        ----------
+        batch_size : int, default 1
+            Number of parallel streams that will be processed in the next
+            call(s). If this differs from the current buffer's batch
+            dimension, the buffer is re-allocated accordingly.
+        """
+        if self.buffer.size(0) != batch_size:
+            self.buffer = torch.zeros(
+                batch_size,
+                self.buffer.size(1),  # channels
+                self.pad_len,
+                device=self.buffer.device,
+                dtype=self.buffer.dtype,
+            )
+        else:
+            self.buffer.zero_()
+        return
diff --git a/pytorch_tcn/tcn.py b/pytorch_tcn/tcn.py
index bc75e3b..0f89caf 100644
--- a/pytorch_tcn/tcn.py
+++ b/pytorch_tcn/tcn.py
@@ -171,12 +171,22 @@ def _init_weights(m):
 
         return
 
-    def reset_buffers(self):
-        def _reset_buffer(x):
-            if isinstance(x, (TemporalPad1d,) ):
-                x.reset_buffer()
-        self.apply(_reset_buffer)
-        return
+    def reset_buffers(self, batch_size: int = 1):
+        """
+        Reset the streaming buffers of every TemporalPad1d module.
+
+        Parameters
+        ----------
+        batch_size : int, default 1
+            Number of parallel streams that will be processed in the next
+            inference call(s). The underlying padders re-allocate their
+            buffer if the size does not match.
+        """
+        def _reset_buffer(x):
+            if isinstance(x, (TemporalPad1d,) ):
+                x.reset_buffer(batch_size=batch_size)
+        self.apply(_reset_buffer)
+        return
 
     def get_buffers(self):
         """