|
13 | 13 |
|
14 | 14 |
|
15 | 15 | def rewrite_random_normal(g, ops):
|
16 |
| - pattern = \ |
| 16 | + pattern1 = \ |
17 | 17 | OpTypePattern('Add', name='output', inputs=[
|
18 | 18 | OpTypePattern('Mul', name='input2', inputs=[
|
19 | 19 | OpTypePattern('RandomStandardNormal', name='input1', inputs=["*"]), "*"
|
20 | 20 | ]), "*"
|
21 | 21 | ])
|
22 | 22 |
|
23 |
| - matcher = GraphMatcher(pattern) |
24 |
| - match_results = list(matcher.match_ops(ops)) |
25 |
| - for match in match_results: |
26 |
| - output = match.get_op('output') |
27 |
| - mean = output.inputs[1].get_tensor_value() |
28 |
| - dtype = g.get_dtype(output.output[0]) |
29 |
| - op_name = utils.make_name("RandomNormal") |
30 |
| - out_name = utils.port_name(op_name) |
31 |
| - |
32 |
| - rn_op = match.get_op('input1') |
33 |
| - seed = rn_op.get_attr('seed2').i |
34 |
| - if rn_op.inputs[0].type == "Shape": |
35 |
| - shape_node = rn_op.inputs[0] |
36 |
| - new_node = g.make_node("RandomNormalLike", [shape_node.input[0]], outputs=[out_name], name=op_name, |
37 |
| - attr={"mean": mean, "scale": 1.0, "dtype": dtype, "seed": seed}) |
38 |
| - else: |
39 |
| - shape = g.get_shape(output.output[0]) |
40 |
| - new_node = g.make_node("RandomNormal", [], outputs=[out_name], name=op_name, |
41 |
| - attr={"shape": shape, "mean": mean, "scale": 1.0, "dtype": dtype, "seed": seed}) |
42 |
| - |
43 |
| - g.replace_all_inputs(ops, output.output[0], new_node.output[0]) |
44 |
| - g.safe_remove_nodes(match.get_nodes()) |
| 23 | + pattern2 = \ |
| 24 | + OpTypePattern('Identity', name='output', inputs=[ |
| 25 | + OpTypePattern('Identity', name='input2', inputs=[ |
| 26 | + OpTypePattern('RandomStandardNormal', name='input1', inputs=["*"]) |
| 27 | + ]) |
| 28 | + ]) |
| 29 | + |
| 30 | + pattern_list = [pattern1, pattern2] |
| 31 | + for pattern in pattern_list: |
| 32 | + matcher = GraphMatcher(pattern) |
| 33 | + match_results = list(matcher.match_ops(ops)) |
| 34 | + for match in match_results: |
| 35 | + output = match.get_op('output') |
| 36 | + if output.type == 'Add': |
| 37 | + # pattern 1 |
| 38 | + mean = output.inputs[1].get_tensor_value() |
| 39 | + else: |
| 40 | + # pattern 2 |
| 41 | + mean = 0.0 |
| 42 | + dtype = g.get_dtype(output.output[0]) |
| 43 | + op_name = utils.make_name("RandomNormal") |
| 44 | + out_name = utils.port_name(op_name) |
| 45 | + |
| 46 | + rn_op = match.get_op('input1') |
| 47 | + seed = rn_op.get_attr('seed2').i |
| 48 | + |
| 49 | + if rn_op.inputs[0].type == "Shape": |
| 50 | + shape_node = rn_op.inputs[0] |
| 51 | + new_node = g.make_node("RandomNormalLike", [shape_node.input[0]], outputs=[out_name], name=op_name, |
| 52 | + attr={"mean": mean, "scale": 1.0, "dtype": dtype, "seed": float(seed)}) |
| 53 | + else: |
| 54 | + shape = g.get_shape(output.output[0]) |
| 55 | + new_node = g.make_node("RandomNormal", [], outputs=[out_name], name=op_name, |
| 56 | + attr={"shape": shape, "mean": mean, "scale": 1.0, "dtype": dtype, "seed": seed}) |
| 57 | + |
| 58 | + g.replace_all_inputs(ops, output.output[0], new_node.output[0]) |
| 59 | + g.safe_remove_nodes(match.get_nodes()) |
45 | 60 | return ops
|
0 commit comments