@@ -853,6 +853,8 @@ def version_4(cls, ctx, node, **kwargs):
853
853
utils .make_sure (len (blocksize ) == 2 and blocksize [0 ] == blocksize [1 ],
854
854
"only support same blocksize at different dims" )
855
855
856
+ shapes = [ctx .get_shape (node .output [0 ])]
857
+ dtypes = [ctx .get_dtype (node .output [0 ])]
856
858
ctx .remove_node (node .name )
857
859
858
860
# implement pads logic, the data format is NHWC
@@ -866,4 +868,5 @@ def version_4(cls, ctx, node, **kwargs):
866
868
# NHWC TO CNHW, so onnx op will work on "N" which is the same as tensorflow
867
869
trans1 = ctx .make_node ("Transpose" , pad_op .output , {"perm" : [3 , 0 , 1 , 2 ]})
868
870
reorganize_node = ctx .make_node (node .type , trans1 .output , attr = {"blocksize" : blocksize [0 ]})
869
- ctx .make_node ("Transpose" , reorganize_node .output , {"perm" : [1 , 2 , 3 , 0 ]}, name = node .name , outputs = node .output )
871
+ ctx .make_node ("Transpose" , reorganize_node .output , {"perm" : [1 , 2 , 3 , 0 ]}, name = node .name , outputs = node .output ,
872
+ shapes = shapes , dtypes = dtypes )
0 commit comments