15 changes: 10 additions & 5 deletions pytorch_tcn/conv.py
@@ -131,11 +131,15 @@ def inference(self, *args, **kwargs):
         )
         return
 
-    def reset_buffer(self):
-        self.padder.reset_buffer()
+    def reset_buffer(self, batch_size: int = 1):
+        """
+        Reset the internal history buffer.
+        If you plan to process N parallel real-time streams in the next
+        call(s), pass batch_size=N so the buffer is resized accordingly.
+        """
+        self.padder.reset_buffer(batch_size=batch_size)
         return
 
 
 class TemporalConvTranspose1d(nn.ConvTranspose1d):
     def __init__(
             self,
@@ -329,6 +333,7 @@ def inference(self, *args, **kwargs):
         )
         return
 
-    def reset_buffer(self):
-        self.padder.reset_buffer()
+    def reset_buffer(self, batch_size: int = 1):
+        self.padder.reset_buffer(batch_size=batch_size)
         return
 
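Taken together, the two conv.py hunks simply forward the new `batch_size` argument to the underlying padder. A minimal usage sketch, assuming the first patched class is `TemporalConv1d` (only `TemporalConvTranspose1d` is named in the diff) and that `inference()` takes the input tensor directly:

```python
import torch
from pytorch_tcn.conv import TemporalConv1d  # assumed name of the first patched class

# Illustrative constructor arguments
conv = TemporalConv1d(in_channels=16, out_channels=32, kernel_size=3)
conv.eval()

# Size the history buffer for 4 parallel real-time streams ...
conv.reset_buffer(batch_size=4)

# ... then stream chunks with a matching batch dimension (N, C, T).
with torch.no_grad():
    for _ in range(10):
        chunk = torch.randn(4, 16, 128)
        out = conv.inference(chunk)
```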
73 changes: 48 additions & 25 deletions pytorch_tcn/pad.py
@@ -129,28 +129,38 @@ def pad_inference(
"""
)

if x.shape[0] != 1:
batch_size = x.size(0)

def _align_buffer(buf: torch.Tensor, name: str) -> torch.Tensor:
# buf is None -> handled by caller
if buf is None:
return None
if buf.size(0) == batch_size:
return buf
if buf.size(0) == 1:
# replicate the single history for every element in the batch
return buf.repeat(batch_size, 1, 1)
raise ValueError(
f"""
Streaming inference requires a batch size
of 1, but batch size is {x.shape[0]}.
"""
)
f"Batch mismatch between input (N={batch_size}) and "
f"{name} (N={buf.size(0)}). Either supply a buffer of the "
f"same batch size or call .reset_buffer(batch_size=N)."
)


if buffer_io is None:
in_buffer = self.buffer
in_buffer = _align_buffer(self.buffer, 'internal buffer')
else:
in_buffer = buffer_io.next_in_buffer()
if in_buffer is None:
if in_buffer is None: # first iteration, fall back
in_buffer = self.buffer
buffer_io.append_internal_buffer( in_buffer )

x = torch.cat(
(in_buffer, x),
-1,
)

out_buffer = x[ ..., -self.pad_len: ]
in_buffer = _align_buffer(in_buffer, 'in_buffer from BufferIO')
buffer_io.append_internal_buffer(in_buffer)

# pad the current input with the previous history
x = torch.cat((in_buffer, x), dim=-1)

# remember the most recent history for the *next* call
out_buffer = x[..., -self.pad_len:]
if buffer_io is None:
self.buffer = out_buffer
else:
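The replication branch above is plain `Tensor.repeat` semantics: a `(1, C, pad_len)` history is copied to every stream, matching batch sizes pass through unchanged, and anything else raises. A standalone sketch of the rule (shapes are illustrative):

```python
import torch

batch_size = 4
buf = torch.zeros(1, 16, 7)  # single-stream history: (1, channels, pad_len)

# Same alignment rule as _align_buffer, shown outside the class for clarity.
if buf.size(0) == 1 and batch_size > 1:
    buf = buf.repeat(batch_size, 1, 1)  # copy the history to every stream

assert buf.shape == (4, 16, 7)
```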
@@ -170,12 +180,25 @@ def forward(
         x = self.pad(x)
         return x
 
-    def reset_buffer(self):
-        self.buffer.zero_()
-        if self.buffer.shape[-1] != self.pad_len:
-            raise ValueError(
-                f"""
-                Buffer shape {self.buffer.shape} does not match the expected
-                shape (1, {self.in_channels}, {self.pad_len}).
-                """
-            )
+    def reset_buffer(self, batch_size: int = 1) -> None:
+        """
+        Reset the streaming buffer to zeros.
+
+        Parameters
+        ----------
+        batch_size : int, default 1
+            Number of parallel streams that will be processed in the next
+            call(s). If this differs from the current buffer's batch
+            dimension, the buffer is re-allocated accordingly.
+        """
+        if self.buffer.size(0) != batch_size:
+            self.buffer = torch.zeros(
+                batch_size,
+                self.buffer.size(1),  # channels
+                self.pad_len,
+                device=self.buffer.device,
+                dtype=self.buffer.dtype,
+            )
+        else:
+            self.buffer.zero_()
+        return
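Note the design choice here: the buffer is re-allocated only when the batch dimension actually changes (preserving device and dtype); otherwise it is zeroed in place, avoiding a fresh allocation on every reset. A standalone mimic of that logic in plain torch, with a hypothetical free function standing in for the method:

```python
import torch

def reset_buffer(buffer: torch.Tensor, pad_len: int, batch_size: int = 1) -> torch.Tensor:
    # Mirrors the patched method: re-allocate only on a size change,
    # otherwise zero the existing storage in place.
    if buffer.size(0) != batch_size:
        return torch.zeros(
            batch_size, buffer.size(1), pad_len,
            device=buffer.device, dtype=buffer.dtype,
        )
    return buffer.zero_()

buf = torch.zeros(1, 16, 7)
buf = reset_buffer(buf, pad_len=7, batch_size=4)  # re-allocated to (4, 16, 7)
buf = reset_buffer(buf, pad_len=7, batch_size=4)  # zeroed in place, no new allocation
```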
22 changes: 16 additions & 6 deletions pytorch_tcn/tcn.py
@@ -171,12 +171,22 @@ def _init_weights(m):

         return
 
-    def reset_buffers(self):
-        def _reset_buffer(x):
-            if isinstance(x, (TemporalPad1d,) ):
-                x.reset_buffer()
-        self.apply(_reset_buffer)
-        return
+    def reset_buffers(self, batch_size: int = 1):
+        """
+        Reset the streaming buffers of every TemporalPad1d module.
+
+        Parameters
+        ----------
+        batch_size : int, default 1
+            Number of parallel streams that will be processed in the next
+            inference call(s). The underlying padders re-allocate their
+            buffer if the size does not match.
+        """
+        def _reset_buffer(x):
+            if isinstance(x, (TemporalPad1d,) ):
+                x.reset_buffer(batch_size=batch_size)
+        self.apply(_reset_buffer)
+        return

     def get_buffers(self):
         """
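End to end, the PR makes multi-stream streaming inference work out of the box: size the buffers once with `reset_buffers(batch_size=N)`, then feed chunks whose batch dimension is N. A hedged sketch with the top-level `TCN` class (the constructor arguments and the `inference()` entry point are assumptions, following the same pattern as the patched conv wrappers above):

```python
import torch
from pytorch_tcn import TCN

# Illustrative layer sizes
tcn = TCN(num_inputs=16, num_channels=[32, 32], kernel_size=3)
tcn.eval()

# Size every internal TemporalPad1d buffer for 4 parallel streams.
tcn.reset_buffers(batch_size=4)

with torch.no_grad():
    for _ in range(100):
        chunk = torch.randn(4, 16, 1)  # (N, C, T): one new step per stream
        y = tcn.inference(chunk)
```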