@@ -55,7 +55,7 @@ def _is_node_supported_u55(self, node: fx.Node):
55
55
56
56
C_in = shape_in [1 ]
57
57
C_out = shape_out [1 ]
58
- if (C_in == group ) and (C_out % C_in ) == 0 :
58
+ if (C_in == group ) and (C_out % C_in ) == 0 and len ( shape_in ) <= 4 :
59
59
# Depthwise convolution
60
60
for dim in shape_in [1 :]:
61
61
if not 1 <= dim <= 65536 :
@@ -74,13 +74,19 @@ def _is_node_supported_u55(self, node: fx.Node):
74
74
75
75
kernel_w = kernel [2 ]
76
76
kernel_h = kernel [3 ] if len (kernel ) > 3 else 1
77
+ kernel_z = kernel [4 ] if len (kernel ) > 4 else 1
77
78
# Kernel condition misses constraint on sum of absolute weights
78
79
if not 1 <= kernel_h <= 64 or not 1 <= kernel_w * kernel_h <= 4096 :
79
80
self .reporter .report_reject (
80
81
node ,
81
82
f"Convolution needs to have kernel_y<=64, kernel_x*kernel_y<=4096, got kernel ({ kernel_w } , { kernel_h } )" ,
82
83
)
83
84
return False
85
+ if kernel_z != 1 :
86
+ self .reporter .report_reject (
87
+ node , f"Convolution3d needs to have kernel_z==1, got { kernel_z } ."
88
+ )
89
+ return False
84
90
85
91
if not self ._stride_condition (node ):
86
92
self .reporter .report_reject (
@@ -107,6 +113,14 @@ def _stride_condition(self, node: fx.Node) -> bool:
107
113
if len (strides ) == 1 :
108
114
strides = [strides [0 ]] * 2
109
115
116
+ if len (strides ) > 2 :
117
+ stride_z = strides [2 ]
118
+ if stride_z > 1 :
119
+ self .reporter .report_reject (
120
+ node , f"Convolution3d only supports stride_z<=1, got { stride_z } ."
121
+ )
122
+ return False
123
+
110
124
for stride , dilation in zip (strides , dilations ):
111
125
stride_condition = 1 <= stride <= 3
112
126
dilation_condition = (not has_padding ) and (dilation == 1 )
0 commit comments