diff --git a/README.md b/README.md
index b00ef87..da884b2 100644
--- a/README.md
+++ b/README.md
@@ -42,6 +42,9 @@ model = TCN(
     lookahead: int = 0,
     output_projection: Optional[ int ] = None,
     output_activation: Optional[ str ] = None,
+    force_residual_conv: bool = False,
+    use_separate_skip_connection_output: bool = False,
+    skip_connection_operation: str = 'sum'
 )
 # Continue to train/use model for your task
 ```
@@ -56,7 +59,7 @@ The order of output dimensions will be the same as for the input tensors.
 - `num_inputs`: The number of input channels, should be equal to the feature dimension of your data.
 - `num_channels`: A list or array that contains the number of feature channels in each residual block of the network.
 - `kernel_size`: The size of the convolution kernel used by the convolutional layers. Good starting points may be 2-8. If the prediction task requires large context sizes, larger kernel size values may be appropriate.
-- `dilations`: If None, the dilation sizes will be calculated via 2^(1...n) for the residual blocks 1 to n. This is the standard way to do it. However, if you need a custom list of dilation sizes for whatever reason you could pass such a list or array to the argument.
+- `dilations`: If None, the dilation sizes will be calculated via 2^(1...n) for the residual blocks 1 to n. This is the standard way to do it. However, if you need a custom list of dilation sizes for whatever reason, you can pass such a list or array to the argument. Each element of this array-like argument may be either a single integer or another array-like. If an element is a single integer, the corresponding TCN block will have a network architecture as described in figure 1b) of [Bai et al.](https://arxiv.org/abs/1803.01271). If an element is itself an array-like, the corresponding TCN block will have an architecture similar to what is described in figure 2 of [Pinto et al.](https://www.mdpi.com/2079-9292/10/13/1518#), with multiple dilated convolutions of different dilation rates within the same TCN block.
 - `dilation_reset`: For deep TCNs the dilation size should be reset periodically, otherwise it grows exponentially and the corresponding padding becomes so large that memory overflow occurs (see [Van den Oord et al.](https://arxiv.org/pdf/1609.03499.pdf)). E.g. 'dilation_reset=16' would reset the dilation size once it reaches a value of 16, so the dilation sizes would look like this: [ 1, 2, 4, 8, 16, 1, 2, 4, ...].
 - `dropout`: Is a float value between 0 and 1 that indicates the amount of inputs which are randomly set to zero during training. Usually, 0.1 is a good starting point.
 - `causal`: If 'True', the dilated convolutions will be causal, which means that future information is ignored in the prediction task. This is important for real-time predictions. If set to 'False', future context will be considered for predictions.
@@ -71,6 +74,9 @@ The order of output dimensions will be the same as for the input tensors.
 - `lookahead`: If not 0, causal TCNs will use a lookahead on future time frames to increase the modelling accuracy. The lookahead parameter specifies the number of future time steps that will be processed influence the prediction for a specific time step. Default is 0. Will be ignored for non-causal networks which already have the maximum lookahead possible.
 - `output_projection`: If not None, the output of the TCN will be projected to the specified dimension via a 1x1 convolution. This may be useful if the output of the TCN is supposed to be of a different dimension than the input or if the last activation should be linear. If None, no projection will be performed. The default is 'None'.
 - `output_activation`: If not None, the output of the TCN will be passed through the specified activation function. This maybe useful to etablish a classification head via softmax etc. If None, no activation will be performed. The default is 'None'.
+- `force_residual_conv`: If 'True', the optional 1x1 convolution will always be applied to the residual of each TCN block, even if the number of input and output channels of the TCN block is identical. The default is 'False'.
+- `use_separate_skip_connection_output`: Determines whether the output of the skip connections is returned as a secondary output (if 'True') in addition to the regular output of the TCN, or whether the skip connection output replaces the regular output (if 'False'). This flag only has an effect if the parameter `use_skip_connections` is 'True'. The default is 'False'.
+- `skip_connection_operation`: The operation to apply when combining the skip connection outputs. Can be 'sum' or 'concat'. When set to 'sum', the skip outputs of all layers are summed element-wise. When set to 'concat', the skip outputs of all layers are concatenated along the channel dimension, resulting in an output whose number of channels equals the sum of the entries of `num_channels`. The default is 'sum'.
 
 ## Streaming Inference
 
diff --git a/pytorch_tcn/tcn.py b/pytorch_tcn/tcn.py
index 62c3022..7233678 100644
--- a/pytorch_tcn/tcn.py
+++ b/pytorch_tcn/tcn.py
@@ -295,84 +295,96 @@ def __init__(
         causal,
         use_norm,
         activation,
-        kerner_initializer,
+        kernel_initializer,
         embedding_shapes,
         embedding_mode,
         use_gate,
         lookahead,
+        force_residual_conv,
     ):
         super(TemporalBlock, self).__init__()
         self.use_norm = use_norm
         self.activation = activation
-        self.kernel_initializer = kerner_initializer
+        self.kernel_initializer = kernel_initializer
         self.embedding_shapes = embedding_shapes
         self.embedding_mode = embedding_mode
         self.use_gate = use_gate
         self.causal = causal
         self.lookahead = lookahead
 
-        if self.use_gate:
-            conv1d_n_outputs = 2 * n_outputs
-        else:
-            conv1d_n_outputs = n_outputs
-
-        if self.causal:
-            self.conv1 = CausalConv1d(
-                in_channels=n_inputs,
-                out_channels=conv1d_n_outputs,
-                kernel_size=kernel_size,
-                stride=stride,
-                dilation=dilation,
-                lookahead=self.lookahead,
-            )
-
-            self.conv2 = CausalConv1d(
-                in_channels=n_outputs,
-                out_channels=n_outputs,
-                kernel_size=kernel_size,
-                stride=stride,
-                dilation=dilation,
-                lookahead=self.lookahead,
-            )
+        if isinstance(dilation, int):
+            dilation = [dilation]
+
+        n_multiplier_gate = 2 if self.use_gate else 1
+        conv1d_n_outputs = n_multiplier_gate * n_outputs
+
+        conv1 = []
+        for i in range(len(dilation)):
+            if self.causal:
+                conv1 += [
+                    CausalConv1d(
+                        in_channels=n_inputs,
+                        out_channels=conv1d_n_outputs,
+                        kernel_size=kernel_size,
+                        stride=stride,
+                        dilation=dilation[i],
+                        lookahead=self.lookahead,
+                    )
+                ]
+            else:
+                conv1 += [
+                    TemporalConv1d(
+                        in_channels=n_inputs,
+                        out_channels=conv1d_n_outputs,
+                        kernel_size=kernel_size,
+                        stride=stride,
+                        dilation=dilation[i],
+                    )
+                ]
+        if len(dilation) == 1:
+            if self.causal:
+                self.conv2 = CausalConv1d(
+                    in_channels=n_outputs,
+                    out_channels=n_outputs,
+                    kernel_size=kernel_size,
+                    stride=stride,
+                    dilation=dilation[0],
+                    lookahead=self.lookahead,
+                )
+            else:
+                self.conv2 = TemporalConv1d(
+                    in_channels=n_outputs,
+                    out_channels=n_outputs,
+                    kernel_size=kernel_size,
+                    stride=stride,
+                    dilation=dilation[0],
+                )
         else:
-            self.conv1 = TemporalConv1d(
-                in_channels=n_inputs,
-                out_channels=conv1d_n_outputs,
-                kernel_size=kernel_size,
-                stride=stride,
-                dilation=dilation,
-            )
-
-            self.conv2 = TemporalConv1d(
-                in_channels=n_outputs,
+            self.conv2 = nn.Conv1d(
+                in_channels=n_outputs * len(dilation),
                 out_channels=n_outputs,
-                kernel_size=kernel_size,
-                stride=stride,
-                dilation=dilation,
-            )
-
+                kernel_size=1
+            )
+
+        n_norm1 = n_outputs * n_multiplier_gate * len(dilation)
         if use_norm == 'batch_norm':
-            if self.use_gate:
-                self.norm1 = nn.BatchNorm1d(2 * n_outputs)
-            else:
-                self.norm1 = nn.BatchNorm1d(n_outputs)
+            self.norm1 = nn.BatchNorm1d(n_norm1)
             self.norm2 = nn.BatchNorm1d(n_outputs)
         elif use_norm == 'layer_norm':
-            if self.use_gate:
-                self.norm1 = nn.LayerNorm(2 * n_outputs)
-            else:
-                self.norm1 = nn.LayerNorm(n_outputs)
+            self.norm1 = nn.LayerNorm(n_norm1)
             self.norm2 = nn.LayerNorm(n_outputs)
         elif use_norm == 'weight_norm':
             self.norm1 = None
             self.norm2 = None
-            self.conv1 = weight_norm(self.conv1)
+            conv1 = [weight_norm(conv1[i]) for i in range(len(dilation))]
             self.conv2 = weight_norm(self.conv2)
         elif use_norm is None:
             self.norm1 = None
             self.norm2 = None
 
+        self.conv1 = nn.ModuleList(conv1)
+
         if isinstance( self.activation, str ):
             self.activation1 = activation_fn[ self.activation ]()
             self.activation2 = activation_fn[ self.activation ]()
@@ -387,14 +399,12 @@ def __init__(
 
         self.dropout1 = nn.Dropout(dropout)
         self.dropout2 = nn.Dropout(dropout)
-
-        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1, padding=0) if n_inputs != n_outputs else None
+
+        do_downsample = n_inputs != n_outputs or force_residual_conv
+        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1, padding=0) if do_downsample else None
 
         if self.embedding_shapes is not None:
-            if self.use_gate:
-                embedding_layer_n_outputs = 2 * n_outputs
-            else:
-                embedding_layer_n_outputs = n_outputs
+            embedding_layer_n_outputs = n_outputs * n_multiplier_gate * len(dilation)
 
             self.embedding_projection_1 = nn.Conv1d(
                 in_channels = sum( [ shape[0] for shape in self.embedding_shapes ] ),
@@ -416,10 +426,11 @@ def init_weights(self):
             name=self.kernel_initializer,
             activation=self.activation,
         )
-        initialize(
-            self.conv1.weight,
-            **kwargs
-        )
+        for i in range(len(self.conv1)):
+            initialize(
+                self.conv1[i].weight,
+                **kwargs
+            )
         initialize(
             self.conv2.weight,
             **kwargs
@@ -494,8 +505,9 @@ def forward(
         embeddings,
         inference,
     ):
-        out = self.conv1(x, inference=inference)
-        out = self.apply_norm( self.norm1, out )
+        out = [self.conv1[i](x, inference=inference) for i in range(len(self.conv1))]
+        out = torch.cat(out, dim=1)
+        out = self.apply_norm(self.norm1, out)
 
         if embeddings is not None:
             out = self.apply_embeddings( out, embeddings )
@@ -503,7 +515,7 @@ def forward(
         out = self.activation1(out)
         out = self.dropout1(out)
 
-        out = self.conv2(out, inference=inference)
+        out = self.conv2(out, inference=inference) if len(self.conv1) == 1 else self.conv2(out)
         out = self.apply_norm( self.norm2, out )
         out = self.activation2(out)
         out = self.dropout2(out)
@@ -550,6 +562,9 @@ def __init__(
         lookahead: int = 0,
         output_projection: Optional[ int ] = None,
         output_activation: Optional[ str ] = None,
+        force_residual_conv: bool = False,
+        use_separate_skip_connection_output: bool = False,
+        skip_connection_operation: str = 'sum'
     ):
         super(TCN, self).__init__()
 
@@ -558,6 +573,7 @@ def __init__(
 
         self.allowed_norm_values = ['batch_norm', 'layer_norm', 'weight_norm', None]
         self.allowed_input_shapes = ['NCL', 'NLC']
+        self.allowed_skip_connection_operations = ['sum', 'concat']
         _check_generic_input_arg( causal, 'causal', [True, False] )
         _check_generic_input_arg( use_norm, 'use_norm', self.allowed_norm_values )
@@ -568,6 +584,8 @@ def __init__(
         _check_generic_input_arg( embedding_mode, 'embedding_mode', ['add', 'concat'] )
         _check_generic_input_arg( use_gate, 'use_gate', [True, False] )
         _check_activation_arg(output_activation, 'output_activation')
+        _check_generic_input_arg(skip_connection_operation, 'skip_connection_operation',
+                                 self.allowed_skip_connection_operations)
 
         if dilations is None:
             if dilation_reset is None:
@@ -584,6 +602,7 @@ def __init__(
         self.activation = activation
         self.kernel_initializer = kernel_initializer
         self.use_skip_connections = use_skip_connections
+        self.use_separate_skip_connection_output = use_separate_skip_connection_output
         self.input_shape = input_shape
         self.embedding_shapes = embedding_shapes
         self.embedding_mode = embedding_mode
@@ -592,6 +611,7 @@ def __init__(
         self.lookahead = lookahead
         self.output_projection = output_projection
         self.output_activation = output_activation
+        self.skip_connection_operation = skip_connection_operation
 
         if embedding_shapes is not None:
             if isinstance(embedding_shapes, Iterable):
@@ -627,7 +647,7 @@ def __init__(
             self.downsample_skip_connection = nn.ModuleList()
             for i in range( len( num_channels ) ):
                 # Downsample layer output dim to network output dim if needed
-                if num_channels[i] != num_channels[-1]:
+                if skip_connection_operation == 'sum' and num_channels[i] != num_channels[-1]:
                     self.downsample_skip_connection.append(
                         nn.Conv1d( num_channels[i], num_channels[-1], 1 )
                     )
@@ -661,11 +681,12 @@ def __init__(
                     causal=causal,
                     use_norm=use_norm,
                     activation=activation,
-                    kerner_initializer=self.kernel_initializer,
+                    kernel_initializer=self.kernel_initializer,
                     embedding_shapes=self.embedding_shapes,
                     embedding_mode=self.embedding_mode,
                     use_gate=self.use_gate,
                     lookahead=self.lookahead,
+                    force_residual_conv=force_residual_conv
                 )
             ]
 
@@ -736,8 +757,17 @@ def forward(
                 if index < len( self.network ) - 1:
                     skip_connections.append( skip_out )
             skip_connections.append( x )
-            x = torch.stack( skip_connections, dim=0 ).sum( dim=0 )
-            x = self.activation_skip_out( x )
+            if self.skip_connection_operation == 'sum':
+                x_skip = torch.stack( skip_connections, dim=0 ).sum( dim=0 )
+            elif self.skip_connection_operation == 'concat':
+                x_skip = torch.cat( skip_connections, dim=1 )
+            else:
+                raise NotImplementedError(
+                    f"skip_connection_operation '{self.skip_connection_operation}' is not implemented!"
+                )
+            x_skip = self.activation_skip_out( x_skip )
+            if not self.use_separate_skip_connection_output:
+                x = x_skip
         else:
             for layer in self.network:
                 #print( 'TCN, embeddings:', embeddings.shape )
@@ -748,13 +778,25 @@ def forward(
                 )
         if self.projection_out is not None:
             x = self.projection_out( x )
+            if self.use_skip_connections and self.use_separate_skip_connection_output:
+                x_skip = self.projection_out( x_skip )
         if self.activation_out is not None:
             x = self.activation_out( x )
+            if self.use_skip_connections and self.use_separate_skip_connection_output:
+                x_skip = self.activation_out( x_skip )
         if inference and self.lookahead > 0:
             x = x[ :, :, self.lookahead: ]
+            if self.use_skip_connections and self.use_separate_skip_connection_output:
+                x_skip = x_skip[ :, :, self.lookahead: ]
         if self.input_shape == 'NLC':
             x = x.transpose(1, 2)
-        return x
+            if self.use_skip_connections and self.use_separate_skip_connection_output:
+                x_skip = x_skip.transpose(1, 2)
+
+        if self.use_skip_connections and self.use_separate_skip_connection_output:
+            return x, x_skip
+        else:
+            return x
 
     def inference(
         self,
diff --git a/tests/unit/test_tcn.py b/tests/unit/test_tcn.py
index d68b97c..b53f7d0 100644
--- a/tests/unit/test_tcn.py
+++ b/tests/unit/test_tcn.py
@@ -23,6 +23,7 @@ def generate_combinations(test_args):
            dict(
                kwargs = combination_dict,
                expected_error = x['expected_error'],
+               expected_outputs = x.get('expected_outputs')
            )
        )
 
@@ -71,6 +72,9 @@ def __init__(self, methodName: str = "runTest") -> None:
                kwargs = dict(
                    dilations = [
                        [1, 2, 3, 4, 1, 2, 3, 4],
+                       [[1, 2], [2, 4], [3, 6], [4, 8], [1, 2], [2, 4], [3, 6], [4, 8]],
+                       [[1, 2], [2, 4], [3, 6], [4, 8], 1, 2, 3, 4],
+                       [1, 2, 3, 4, [1, 2], [2, 4], [3, 6], [4, 8]],
                        None,
                    ],
                ),
@@ -162,13 +166,58 @@ def __init__(self, methodName: str = "runTest") -> None:
                ),
                expected_error = ValueError,
            ),
+           # Test different values for force_residual_conv
+           dict(
+               kwargs = dict( force_residual_conv = [True, False] ),
+               expected_error = None,
+           ),
+           # Test different values for use_separate_skip_connection_output
+           dict(
+               kwargs = dict(
+                   use_skip_connections = [True],
+                   use_separate_skip_connection_output = [True]
+               ),
+               expected_error = None,
+               expected_outputs = 2
+           ),
+           dict(
+               kwargs=dict(
+                   use_skip_connections = [True, False],
+                   use_separate_skip_connection_output = [False]
+               ),
+               expected_error = None,
+               expected_outputs = 1
+           ),
+           dict(
+               kwargs=dict(
+                   use_skip_connections = [False],
+                   use_separate_skip_connection_output = [True]
+               ),
+               expected_error = None,
+               expected_outputs = 1
+           ),
+           # Test different values for skip_connection_operation
+           dict(
+               kwargs=dict(
+                   use_skip_connections = [True],
+                   skip_connection_operation = ['sum', 'concat']
+               ),
+               expected_error=None,
+           ),
+           dict(
+               kwargs=dict(
+                   use_skip_connections = [True],
+                   skip_connection_operation = ['foo']
+               ),
+               expected_error = ValueError,
+           ),
        ]
 
        self.combinations = generate_combinations(self.test_args)
        return
 
-    def test_tcn(self, **kwargs):
+    def test_tcn(self, expected_outputs=None, **kwargs):
 
        tcn = TCN(
            num_inputs = self.num_inputs,
@@ -184,9 +233,11 @@ def test_tcn(self, **kwargs):
            self.num_inputs,
            self.time_steps,
        )
+       is_skip_operation_concat = (kwargs.get('use_skip_connections', False)
+                                   and kwargs.get('skip_connection_operation', 'sum') == 'concat')
        expected_shape = (
            self.batch_size,
-           self.num_channels[-1],
+           sum(self.num_channels) if is_skip_operation_concat else self.num_channels[-1],
            self.time_steps,
        )
        x_inference = torch.randn(
            1,
            self.num_inputs,
            self.time_steps,
        )
        expected_shape_inference = (
            1,
-           self.num_channels[-1],
+           sum(self.num_channels) if is_skip_operation_concat else self.num_channels[-1],
            self.time_steps - tcn.lookahead,
        )
@@ -250,8 +301,11 @@ def test_tcn(self, **kwargs):
            embeddings_inference = None
 
        y = tcn(x, embeddings = embeddings)
-
-       self.assertEqual( y.shape, expected_shape )
+       if expected_outputs is None or expected_outputs == 1:
+           self.assertEqual( y.shape, expected_shape )
+       else:
+           for i in range(expected_outputs):
+               self.assertEqual( y[i].shape, expected_shape )
 
        # Testing the streaming inference mode for causal models
        if tcn.causal:
@@ -269,8 +323,12 @@ def test_tcn(self, **kwargs):
                )
            #print( 'y_inference shape: ', y_inference.shape)
            #stop
-
-           self.assertEqual( y_inference.shape, expected_shape_inference )
+
+           if expected_outputs is None or expected_outputs == 1:
+               self.assertEqual(y_inference.shape, expected_shape_inference)
+           else:
+               for i in range(expected_outputs):
+                   self.assertEqual(y_inference[i].shape, expected_shape_inference)
 
            # piecewise inference:
            tcn.reset_buffers()
@@ -314,8 +372,13 @@ def test_tcn(self, **kwargs):
                        embeddings = embeddings_frame,
                    )
                )
-           y_inference_frames = torch.cat( y_inference_frames, dim = time_dimension )
-           self.assertEqual( y_inference_frames.shape, expected_shape_inference )
+           if expected_outputs is None or expected_outputs == 1:
+               y_inference_frames = torch.cat(y_inference_frames, dim=time_dimension)
+               self.assertEqual(y_inference_frames.shape, expected_shape_inference)
+           else:
+               for i in range(expected_outputs):
+                   y_inference_concatenated = torch.cat([frame[i] for frame in y_inference_frames], dim=time_dimension)
+                   self.assertEqual(y_inference_concatenated.shape, expected_shape_inference)
            #stop
 
            ## piecewise inference without buffer
@@ -353,10 +416,10 @@ def test_tcn_grid_search(self):
            kwargs = test_dict['kwargs']
            print( 'Testing kwargs: ', kwargs )
            if test_dict['expected_error'] is None:
-               self.test_tcn( **kwargs )
+               self.test_tcn(test_dict.get('expected_outputs'), **kwargs )
            else:
                with self.assertRaises(test_dict['expected_error']):
-                   self.test_tcn( **kwargs )
+                   self.test_tcn(test_dict.get('expected_outputs'), **kwargs )
 
        return
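
A brief usage note for reviewers (not part of the patch): a minimal sketch of the mixed `dilations` specification documented in the updated README hunk, assuming the usual `from pytorch_tcn import TCN` import; the channel counts, kernel size, and sequence length below are made-up example values.

```python
# Minimal sketch: one dilation entry per residual block in num_channels.
# Plain ints give the standard Bai et al. style block; lists give a
# multi-dilation block with several parallel dilated convolutions (Pinto et al.).
import torch
from pytorch_tcn import TCN  # assumed import path of the package

model = TCN(
    num_inputs=20,
    num_channels=[32, 32, 64, 64],
    kernel_size=4,
    dilations=[1, 2, [4, 8], [8, 16]],  # ints and lists may be mixed per block
)

x = torch.randn(8, 20, 100)  # (batch, channels, time), default 'NCL' layout
y = model(x)
print(y.shape)  # expected: torch.Size([8, 64, 100]), i.e. num_channels[-1] channels
```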
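A second sketch, also not part of the patch, illustrating the new skip-connection keywords. Based on the forward changes above, `use_separate_skip_connection_output=True` should make the model return an `(output, skip_output)` tuple, and `skip_connection_operation='concat'` should give the skip output `sum(num_channels)` channels; the exact shapes noted in the comments are assumptions derived from reading the diff, not verified behavior.

```python
# Sketch of the skip-connection options added in this patch.
import torch
from pytorch_tcn import TCN  # assumed import path of the package

model = TCN(
    num_inputs=20,
    num_channels=[32, 32, 64],
    kernel_size=4,
    use_skip_connections=True,
    use_separate_skip_connection_output=True,  # return the skip output separately
    skip_connection_operation='concat',        # concatenate skip outputs over channels
    force_residual_conv=True,                  # always apply the 1x1 residual convolution
)

x = torch.randn(8, 20, 100)
y, y_skip = model(x)          # two outputs when the flag above is True
print(y.shape)                # assumed: torch.Size([8, 64, 100])
print(y_skip.shape)           # assumed: torch.Size([8, 128, 100]) = 32 + 32 + 64 channels
```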