@@ -890,29 +890,39 @@ def unpack_op(ctx, node, name, args):
890
890
891
891
def onehot_op (ctx , node , name , args ):
892
892
# until there is no onehot op in onnx, a workaround using gather from eye
893
- data = node .input [0 ]
894
- shape = ctx .get_shape (data )
895
- shapeo = ctx .get_shape (node .output [0 ])
896
- if len (shape ) != 1 :
893
+ indices_name = node .input [0 ]
894
+ indices_shape = ctx .get_shape (indices_name )
895
+ if len (indices_shape ) != 1 :
897
896
# TODO: this works for rank=1 but tensorflow supports more than this.
898
897
# Same principle should work but we need to implemtn our own eye.
899
898
raise ValueError ("onehot op: only rank1 is supported" )
900
899
axis = node .get_attr ("axis" )
901
- node . set_attr ( " axis" , axis . i )
900
+ # axis becomes axis for gather
902
901
node .set_attr ("axis" , 0 )
903
902
depth = node .inputs [1 ].get_tensor_value ()[0 ]
904
903
on = node .inputs [2 ].get_tensor_value ()[0 ]
905
904
off = node .inputs [3 ].get_tensor_value ()[0 ]
906
905
dtype = node .inputs [2 ].get_tensor_type ()
907
- del node . input [:]
908
- eye = np . eye ( depth , dtype = dtype ) * on
909
- if off != 0 :
906
+ eye = np . eye ( depth , dtype = dtype )
907
+ if on != 0 :
908
+ eye [ eye == 1 ] = on
910
909
eye [eye == 0 ] = off
910
+ else :
911
+ eye [eye == 0 ] = off
912
+ eye [eye == 1 ] = on
911
913
const_name = utils .make_name (node .name )
912
914
ctx .make_const (const_name , "Const" , eye )
915
+ # setup gather inputs
916
+ del node .input [:]
913
917
node .input .append (const_name )
914
- node .input .append (data )
918
+ node .input .append (indices_name )
915
919
node .type = "Gather"
920
+ if axis .i == 0 :
921
+ # TODO: revisit for rank > 1
922
+ name = utils .make_name (node .name )
923
+ transpose_op = ctx .insert_new_node_on_output ("Transpose" , node .output [0 ], name )
924
+ ctx .copy_shape (node .output [0 ], transpose_op .output [0 ])
925
+ return [node , transpose_op ]
916
926
return node
917
927
918
928
0 commit comments