@@ -213,6 +213,9 @@ def __init__(self,
213
213
setattr (self , name , param )
214
214
215
215
def get_active_filter (self , in_nc , out_nc , kernel_size ):
216
+ ### Unsupport for asymmetric kernels
217
+ if self ._filter_size [0 ] != self ._filter_size [1 ]:
218
+ return self .weight [:out_nc , :in_nc , :, :]
216
219
start , end = compute_start_end (self ._filter_size [0 ], kernel_size )
217
220
### if NOT transform kernel, intercept a center filter with kernel_size from largest filter
218
221
filters = self .weight [:out_nc , :in_nc , start :end , start :end ]
@@ -285,6 +288,10 @@ def forward(self, input, kernel_size=None, expand_ratio=None, channel=None):
285
288
ks = int (self ._filter_size [0 ]) if kernel_size == None else int (
286
289
kernel_size )
287
290
291
+ if kernel_size is not None and self ._filter_size [
292
+ 0 ] != self ._filter_size [1 ]:
293
+ _logger .error ("Searching for asymmetric kernels is NOT supported" )
294
+
288
295
groups , weight_in_nc , weight_out_nc = self .get_groups_in_out_nc (in_nc ,
289
296
out_nc )
290
297
@@ -513,6 +520,9 @@ def __init__(self,
513
520
setattr (self , name , param )
514
521
515
522
def get_active_filter (self , in_nc , out_nc , kernel_size ):
523
+ ### Unsupport for asymmetric kernels
524
+ if self ._filter_size [0 ] != self ._filter_size [1 ]:
525
+ return self .weight [:out_nc , :in_nc , :, :]
516
526
start , end = compute_start_end (self ._filter_size [0 ], kernel_size )
517
527
filters = self .weight [:in_nc , :out_nc , start :end , start :end ]
518
528
if self .transform_kernel != False and kernel_size < self ._filter_size [
@@ -584,6 +594,10 @@ def forward(self, input, kernel_size=None, expand_ratio=None, channel=None):
584
594
ks = int (self ._filter_size [0 ]) if kernel_size == None else int (
585
595
kernel_size )
586
596
597
+ if kernel_size is not None and self ._filter_size [
598
+ 0 ] != self ._filter_size [1 ]:
599
+ _logger .error ("Searching for asymmetric kernels is NOT supported" )
600
+
587
601
groups , weight_in_nc , weight_out_nc = self .get_groups_in_out_nc (in_nc ,
588
602
out_nc )
589
603
0 commit comments