@@ -55,7 +55,7 @@ def _is_node_supported_u55(self, node: fx.Node):
5555
5656 C_in = shape_in [1 ]
5757 C_out = shape_out [1 ]
58- if (C_in == group ) and (C_out % C_in ) == 0 :
58+ if (C_in == group ) and (C_out % C_in ) == 0 and len ( shape_in ) <= 4 :
5959 # Depthwise convolution
6060 for dim in shape_in [1 :]:
6161 if not 1 <= dim <= 65536 :
@@ -74,13 +74,19 @@ def _is_node_supported_u55(self, node: fx.Node):
7474
7575 kernel_w = kernel [2 ]
7676 kernel_h = kernel [3 ] if len (kernel ) > 3 else 1
77+ kernel_z = kernel [4 ] if len (kernel ) > 4 else 1
7778 # Kernel condition misses constraint on sum of absolute weights
7879 if not 1 <= kernel_h <= 64 or not 1 <= kernel_w * kernel_h <= 4096 :
7980 self .reporter .report_reject (
8081 node ,
8182 f"Convolution needs to have kernel_y<=64, kernel_x*kernel_y<=4096, got kernel ({ kernel_w } , { kernel_h } )" ,
8283 )
8384 return False
85+ if kernel_z != 1 :
86+ self .reporter .report_reject (
87+ node , f"Convolution3d needs to have kernel_z==1, got { kernel_z } ."
88+ )
89+ return False
8490
8591 if not self ._stride_condition (node ):
8692 self .reporter .report_reject (
@@ -107,6 +113,14 @@ def _stride_condition(self, node: fx.Node) -> bool:
107113 if len (strides ) == 1 :
108114 strides = [strides [0 ]] * 2
109115
116+ if len (strides ) > 2 :
117+ stride_z = strides [2 ]
118+ if stride_z > 1 :
119+ self .reporter .report_reject (
120+ node , f"Convolution3d only supports stride_z<=1, got { stride_z } ."
121+ )
122+ return False
123+
110124 for stride , dilation in zip (strides , dilations ):
111125 stride_condition = 1 <= stride <= 3
112126 dilation_condition = (not has_padding ) and (dilation == 1 )
0 commit comments