@@ -39,16 +39,22 @@ def version_1(cls, ctx, node, **kwargs):
39
39
seed = node .get_attr ("seed" )
40
40
node .set_attr ("seed" , float (seed .f ))
41
41
if len (node .input ) > 0 :
42
- if node .inputs [0 ].is_const ():
43
- shape = node .inputs [0 ].get_tensor_value ()
44
- ctx .remove_input (node , node .input [0 ])
45
- node .set_attr ("shape" , shape )
46
- ctx .set_shape (node .output [0 ], shape )
47
- else :
48
- cast_node = ctx .make_node ("Cast" , node .input , attr = {'to' : onnx_pb .TensorProto .INT64 })
49
- const_node = ctx .make_node ("ConstantOfShape" , cast_node .output )
50
- node .input = const_node .output
51
- node .type = node .type + 'Like'
42
+ shape = node .inputs [0 ].get_tensor_value ()
43
+ ctx .remove_input (node , node .input [0 ])
44
+ node .set_attr ("shape" , shape )
45
+ ctx .set_shape (node .output [0 ], shape )
46
+
47
+ @classmethod
48
+ def version_9 (cls , ctx , node , ** kwargs ):
49
+ if node .inputs [0 ].is_const ():
50
+ version_1 (cls , ctx , node , ** kwargs )
51
+ else :
52
+ seed = node .get_attr ("seed" )
53
+ node .set_attr ("seed" , float (seed .f ))
54
+ cast_node = ctx .make_node ("Cast" , node .input , attr = {'to' : onnx_pb .TensorProto .INT64 })
55
+ const_node = ctx .make_node ("ConstantOfShape" , cast_node .output )
56
+ node .input = const_node .output
57
+ node .type = node .type + 'Like'
52
58
53
59
54
60
@tf_op (["RandomNormalLike" , "RandomUniformLike" ])
0 commit comments