@@ -130,6 +130,7 @@ def __init__(
130130 kernel_size : Union [int , Iterable [int ]],
131131 padding : Union [int , Iterable [int ]] = 0 ,
132132 stride : Union [int , Iterable [int ]] = 1 ,
133+ dilation : Union [int , Iterable [int ]] = 1 ,
133134 groups : int = 1 ,
134135 bias : bool = True ,
135136 ndim : int = 1 ,
@@ -150,6 +151,7 @@ def __init__(
150151 self .kernel_size = kernel_size
151152 self .padding = padding
152153 self .stride = stride
154+ self .dilation = dilation
153155 self .groups = groups
154156 self .use_bias = bias
155157
@@ -165,9 +167,22 @@ def __init__(
165167 )
166168
167169 kernel_size = to_ntuple (kernel_size , ndim )
168- self .weight = nn .Parameter (
169- torch .randn (out_channels , in_channels // groups , * kernel_size )
170+ dilation = to_ntuple (dilation , ndim )
171+ total_size = tuple (
172+ ((ks - 1 ) * dil + 1 )
173+ for ks , dil in zip (kernel_size , dilation )
170174 )
175+ weight = torch .zeros (out_channels , in_channels // groups , * total_size )
176+ fill = torch .randn (out_channels , in_channels // groups , * kernel_size )
177+ ids = tuple (
178+ torch .arange (0 , tot_sz , dil )
179+ for tot_sz , dil in zip (total_size , dilation )
180+ )
181+
182+ # workaround bc PyTorch doesn't support [:, :, <tensor tuple>] indexing
183+ weight [(slice (None ), slice (None ),) + torch .meshgrid (* ids )] = fill
184+
185+ self .weight = nn .Parameter (weight )
171186 self .bias = nn .Parameter (torch .randn (out_channels )) if bias else None
172187
173188 def forward (self , signal ):
0 commit comments