README.md: 8 changes (7 additions, 1 deletion)
@@ -42,6 +42,9 @@ model = TCN(
lookahead: int = 0,
output_projection: Optional[ int ] = None,
output_activation: Optional[ str ] = None,
force_residual_conv: bool = False,
use_separate_skip_connection_output: bool = False,
skip_connection_operation: str = 'sum'
)
# Continue to train/use model for your task
```
@@ -56,7 +59,7 @@ The order of output dimensions will be the same as for the input tensors.
- `num_inputs`: The number of input channels; this should be equal to the feature dimension of your data.
- `num_channels`: A list or array that contains the number of feature channels in each residual block of the network.
- `kernel_size`: The size of the convolution kernel used by the convolutional layers. Good starting points may be 2-8. If the prediction task requires large context sizes, larger kernel size values may be appropriate.
- `dilations`: If None, the dilation sizes will be calculated via 2^(1...n) for the residual blocks 1 to n. This is the standard way to do it. However, if you need a custom list of dilation sizes for whatever reason you could pass such a list or array to the argument.
- `dilations`: If None, the dilation sizes will be calculated via 2^(1...n) for the residual blocks 1 to n. This is the standard way to do it. However, if you need a custom list of dilation sizes, you can pass such a list or array to this argument. Each element of this array-like argument may be either a single integer or another array-like. If an element is a single integer, the corresponding TCN block will have a network architecture as described in figure 1b) of [Bai et al.](https://arxiv.org/abs/1803.01271). If an element is itself an array-like, the corresponding TCN block will have an architecture similar to the one described in figure 2 of [Pinto et al.](https://www.mdpi.com/2079-9292/10/13/1518#), with multiple dilated convolutions of different dilation rates within the same TCN block (see the sketch below).
- `dilation_reset`: For deep TCNs the dilation size should be reset periodically, otherwise it grows exponentially and the corresponding padding becomes so large that memory overflow occurs (see [Van den Oord et al.](https://arxiv.org/pdf/1609.03499.pdf)). E.g. 'dilation_reset=16' would reset the dilation size after it reaches a value of 16, so the dilation sizes would look like this: [ 1, 2, 4, 8, 16, 1, 2, 4, ...].
- `dropout`: A float value between 0 and 1 that indicates the fraction of inputs which are randomly set to zero during training. Usually, 0.1 is a good starting point.
- `causal`: If 'True', the dilated convolutions will be causal, which means that future information is ignored in the prediction task. This is important for real-time predictions. If set to 'False', future context will be considered for predictions.
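
To make the nested `dilations` format concrete, here is a minimal sketch of how per-block dilation rates could be passed. It assumes the usual `from pytorch_tcn import TCN` import and the default 'NCL' input layout; the channel counts, kernel size, and tensor sizes are purely illustrative.

```python
import torch
from pytorch_tcn import TCN

# Hypothetical 3-block TCN: blocks 1 and 3 use a single dilation rate each
# (figure 1b of Bai et al.), while block 2 uses three parallel dilated
# convolutions with different rates inside the same block (Pinto et al. style).
model = TCN(
    num_inputs=16,
    num_channels=[32, 32, 64],
    kernel_size=4,
    dilations=[1, [2, 4, 8], 16],
    causal=True,
)

x = torch.randn(8, 16, 100)  # (batch, channels, time) for the 'NCL' layout
y = model(x)
print(y.shape)  # expected: torch.Size([8, 64, 100])
```
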
@@ -71,6 +74,9 @@ The order of output dimensions will be the same as for the input tensors.
- `lookahead`: If not 0, causal TCNs will use a lookahead on future time frames to increase the modelling accuracy. The lookahead parameter specifies the number of future time steps that will be processed to influence the prediction for a specific time step. Default is 0. Will be ignored for non-causal networks, which already have the maximum possible lookahead.
- `output_projection`: If not None, the output of the TCN will be projected to the specified dimension via a 1x1 convolution. This may be useful if the output of the TCN is supposed to be of a different dimension than the input or if the last activation should be linear. If None, no projection will be performed. The default is 'None'.
- `output_activation`: If not None, the output of the TCN will be passed through the specified activation function. This may be useful to establish a classification head via softmax etc. If None, no activation will be performed. The default is 'None'.
- `force_residual_conv`: If 'True', the optional 1x1 convolution will always be applied to the residual of each TCN block, even if the number of input and output channels of the TCN block is identical. The default is 'False'.
- `use_separate_skip_connection_output`: Determines whether the skip connection output is returned as a secondary output in addition to the regular output of the TCN (if 'True') or used as the primary output (if 'False'). This flag only has an effect if the parameter use_skip_connections is 'True'. The default is 'False'.
- `skip_connection_operation`: The operation used to combine the skip connection outputs. Can be 'sum' or 'concat'. When set to 'sum', the skip outputs of each layer are summed per channel. When set to 'concat', the skip outputs of each layer are concatenated along the channel axis, resulting in an output whose number of channels equals the sum of the entries of num_channels. The default is 'sum' (see the sketch below).
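
The following is a rough sketch of how the new skip connection arguments combine, and of the two-tensor return described for `use_separate_skip_connection_output`. The import path and all sizes are illustrative assumptions; the expected shapes follow from the 'concat' behaviour described above.

```python
import torch
from pytorch_tcn import TCN

model = TCN(
    num_inputs=16,
    num_channels=[32, 32, 64],
    kernel_size=4,
    use_skip_connections=True,
    force_residual_conv=True,                  # always apply the 1x1 conv to the residual
    use_separate_skip_connection_output=True,  # also return the combined skip output
    skip_connection_operation='concat',        # concatenate skip outputs along the channel axis
)

x = torch.randn(8, 16, 100)  # (batch, channels, time)
y, y_skip = model(x)         # regular output plus the separate skip connection output
print(y.shape)       # expected: torch.Size([8, 64, 100])
print(y_skip.shape)  # expected: torch.Size([8, 128, 100]), 128 = 32 + 32 + 64
```

With `skip_connection_operation='sum'`, the skip output would instead keep `num_channels[-1]` channels, since the per-block skip outputs are downsampled to that size and summed.
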

## Streaming Inference

pytorch_tcn/tcn.py: 176 changes (109 additions, 67 deletions)
@@ -295,84 +295,96 @@ def __init__(
causal,
use_norm,
activation,
kerner_initializer,
kernel_initializer,
embedding_shapes,
embedding_mode,
use_gate,
lookahead,
force_residual_conv,
):
super(TemporalBlock, self).__init__()
self.use_norm = use_norm
self.activation = activation
self.kernel_initializer = kerner_initializer
self.kernel_initializer = kernel_initializer
self.embedding_shapes = embedding_shapes
self.embedding_mode = embedding_mode
self.use_gate = use_gate
self.causal = causal
self.lookahead = lookahead

if self.use_gate:
conv1d_n_outputs = 2 * n_outputs
else:
conv1d_n_outputs = n_outputs

if self.causal:
self.conv1 = CausalConv1d(
in_channels=n_inputs,
out_channels=conv1d_n_outputs,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
lookahead=self.lookahead,
)

self.conv2 = CausalConv1d(
in_channels=n_outputs,
out_channels=n_outputs,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
lookahead=self.lookahead,
)
if isinstance(dilation, int):
dilation = [dilation]

n_multiplier_gate = 2 if self.use_gate else 1
conv1d_n_outputs = n_multiplier_gate * n_outputs

conv1 = []
for i in range(len(dilation)):
if self.causal:
conv1 += [
CausalConv1d(
in_channels=n_inputs,
out_channels=conv1d_n_outputs,
kernel_size=kernel_size,
stride=stride,
dilation=dilation[i],
lookahead=self.lookahead,
)
]
else:
conv1 += [
TemporalConv1d(
in_channels=n_inputs,
out_channels=conv1d_n_outputs,
kernel_size=kernel_size,
stride=stride,
dilation=dilation[i],
)
]

if len(dilation) == 1:
if self.causal:
self.conv2 = CausalConv1d(
in_channels=n_outputs,
out_channels=n_outputs,
kernel_size=kernel_size,
stride=stride,
dilation=dilation[0],
lookahead=self.lookahead,
)
else:
self.conv2 = TemporalConv1d(
in_channels=n_outputs,
out_channels=n_outputs,
kernel_size=kernel_size,
stride=stride,
dilation=dilation[0],
)
else:
self.conv1 = TemporalConv1d(
in_channels=n_inputs,
out_channels=conv1d_n_outputs,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
)

self.conv2 = TemporalConv1d(
in_channels=n_outputs,
self.conv2 = nn.Conv1d(
in_channels=n_outputs * len(dilation),
out_channels=n_outputs,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
)

kernel_size=1
)

n_norm1 = n_outputs * n_multiplier_gate * len(dilation)
if use_norm == 'batch_norm':
if self.use_gate:
self.norm1 = nn.BatchNorm1d(2 * n_outputs)
else:
self.norm1 = nn.BatchNorm1d(n_outputs)
self.norm1 = nn.BatchNorm1d(n_norm1)
self.norm2 = nn.BatchNorm1d(n_outputs)
elif use_norm == 'layer_norm':
if self.use_gate:
self.norm1 = nn.LayerNorm(2 * n_outputs)
else:
self.norm1 = nn.LayerNorm(n_outputs)
self.norm1 = nn.LayerNorm(n_norm1)
self.norm2 = nn.LayerNorm(n_outputs)
elif use_norm == 'weight_norm':
self.norm1 = None
self.norm2 = None
self.conv1 = weight_norm(self.conv1)
conv1 = [weight_norm(conv1[i]) for i in range(len(dilation))]
self.conv2 = weight_norm(self.conv2)
elif use_norm is None:
self.norm1 = None
self.norm2 = None

self.conv1 = nn.ModuleList(conv1)

if isinstance( self.activation, str ):
self.activation1 = activation_fn[ self.activation ]()
self.activation2 = activation_fn[ self.activation ]()
@@ -387,14 +399,12 @@ def __init__(

self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)

self.downsample = nn.Conv1d(n_inputs, n_outputs, 1, padding=0) if n_inputs != n_outputs else None

do_downsample = n_inputs != n_outputs or force_residual_conv
self.downsample = nn.Conv1d(n_inputs, n_outputs, 1, padding=0) if do_downsample else None

if self.embedding_shapes is not None:
if self.use_gate:
embedding_layer_n_outputs = 2 * n_outputs
else:
embedding_layer_n_outputs = n_outputs
embedding_layer_n_outputs = n_outputs * n_multiplier_gate * len(dilation)

self.embedding_projection_1 = nn.Conv1d(
in_channels = sum( [ shape[0] for shape in self.embedding_shapes ] ),
@@ -416,10 +426,11 @@ def init_weights(self):
name=self.kernel_initializer,
activation=self.activation,
)
initialize(
self.conv1.weight,
**kwargs
)
for i in range(len(self.conv1)):
initialize(
self.conv1[i].weight,
**kwargs
)
initialize(
self.conv2.weight,
**kwargs
@@ -494,16 +505,17 @@ def forward(
embeddings,
inference,
):
out = self.conv1(x, inference=inference)
out = self.apply_norm( self.norm1, out )
out = [self.conv1[i](x, inference=inference) for i in range(len(self.conv1))]
out = torch.cat(out, dim=1)
out = self.apply_norm(self.norm1, out)

if embeddings is not None:
out = self.apply_embeddings( out, embeddings )

out = self.activation1(out)
out = self.dropout1(out)

out = self.conv2(out, inference=inference)
out = self.conv2(out, inference=inference) if len(self.conv1) == 1 else self.conv2(out)
out = self.apply_norm( self.norm2, out )
out = self.activation2(out)
out = self.dropout2(out)
@@ -550,6 +562,9 @@ def __init__(
lookahead: int = 0,
output_projection: Optional[ int ] = None,
output_activation: Optional[ str ] = None,
force_residual_conv: bool = False,
use_separate_skip_connection_output: bool = False,
skip_connection_operation: str = 'sum'
):
super(TCN, self).__init__()

@@ -558,6 +573,7 @@ def __init__(

self.allowed_norm_values = ['batch_norm', 'layer_norm', 'weight_norm', None]
self.allowed_input_shapes = ['NCL', 'NLC']
self.allowed_skip_connection_operations = ['sum', 'concat']

_check_generic_input_arg( causal, 'causal', [True, False] )
_check_generic_input_arg( use_norm, 'use_norm', self.allowed_norm_values )
@@ -568,6 +584,8 @@ def __init__(
_check_generic_input_arg( embedding_mode, 'embedding_mode', ['add', 'concat'] )
_check_generic_input_arg( use_gate, 'use_gate', [True, False] )
_check_activation_arg(output_activation, 'output_activation')
_check_generic_input_arg(skip_connection_operation, 'skip_connection_operation',
self.allowed_skip_connection_operations)

if dilations is None:
if dilation_reset is None:
@@ -584,6 +602,7 @@ def __init__(
self.activation = activation
self.kernel_initializer = kernel_initializer
self.use_skip_connections = use_skip_connections
self.use_separate_skip_connection_output = use_separate_skip_connection_output
self.input_shape = input_shape
self.embedding_shapes = embedding_shapes
self.embedding_mode = embedding_mode
Expand All @@ -592,6 +611,7 @@ def __init__(
self.lookahead = lookahead
self.output_projection = output_projection
self.output_activation = output_activation
self.skip_connection_operation = skip_connection_operation

if embedding_shapes is not None:
if isinstance(embedding_shapes, Iterable):
@@ -627,7 +647,7 @@ def __init__(
self.downsample_skip_connection = nn.ModuleList()
for i in range( len( num_channels ) ):
# Downsample layer output dim to network output dim if needed
if num_channels[i] != num_channels[-1]:
if skip_connection_operation == 'sum' and num_channels[i] != num_channels[-1]:
self.downsample_skip_connection.append(
nn.Conv1d( num_channels[i], num_channels[-1], 1 )
)
@@ -661,11 +681,12 @@ def __init__(
causal=causal,
use_norm=use_norm,
activation=activation,
kerner_initializer=self.kernel_initializer,
kernel_initializer=self.kernel_initializer,
embedding_shapes=self.embedding_shapes,
embedding_mode=self.embedding_mode,
use_gate=self.use_gate,
lookahead=self.lookahead,
force_residual_conv=force_residual_conv
)
]

@@ -736,8 +757,17 @@ def forward(
if index < len( self.network ) - 1:
skip_connections.append( skip_out )
skip_connections.append( x )
x = torch.stack( skip_connections, dim=0 ).sum( dim=0 )
x = self.activation_skip_out( x )
if self.skip_connection_operation == 'sum':
x_skip = torch.stack( skip_connections, dim=0 ).sum( dim=0 )
elif self.skip_connection_operation == 'concat':
x_skip = torch.cat( skip_connections, dim=1 )
else:
raise NotImplementedError(
f"skip_connection_operation '{self.skip_connection_operation}' is not implemented!"
)
x_skip = self.activation_skip_out( x_skip )
if not self.use_separate_skip_connection_output:
x = x_skip
else:
for layer in self.network:
#print( 'TCN, embeddings:', embeddings.shape )
Expand All @@ -748,13 +778,25 @@ def forward(
)
if self.projection_out is not None:
x = self.projection_out( x )
if self.use_skip_connections and self.use_separate_skip_connection_output:
x_skip = self.projection_out( x_skip )
if self.activation_out is not None:
x = self.activation_out( x )
if self.use_skip_connections and self.use_separate_skip_connection_output:
x_skip = self.activation_out( x_skip )
if inference and self.lookahead > 0:
x = x[ :, :, self.lookahead: ]
if self.use_skip_connections and self.use_separate_skip_connection_output:
x_skip = x_skip[ :, :, self.lookahead: ]
if self.input_shape == 'NLC':
x = x.transpose(1, 2)
return x
if self.use_skip_connections and self.use_separate_skip_connection_output:
x_skip = x_skip.transpose(1, 2)

if self.use_skip_connections and self.use_separate_skip_connection_output:
return x, x_skip
else:
return x

def inference(
self,