|
52 | 52 | id=1, |
53 | 53 | rank=2, |
54 | 54 | out_shapes=lambda shapes: [ |
55 | | - [shapes[0][0], 1, 1, shapes[0][1]], # W |
| 55 | + [shapes[0][0], 1, 1, shapes[0][1]] |
| 56 | + if len(shapes[0]) > 1 |
| 57 | + else [1, 1, 1, 1], # W |
56 | 58 | [1, 1, 1, shapes[-1][-1]], # B |
57 | 59 | ], |
58 | 60 | patch_data=lambda data: [data[0], data[1]], |
|
324 | 326 | "ConcatV2", |
325 | 327 | "Identity", |
326 | 328 | ] |
327 | | - ): "BasicLSTM", |
328 | | - repr([re.compile("^lstm/"), "Reshape", "ConcatV2", "Identity"]): "BasicLSTM", |
329 | | - repr(["Reshape", re.compile("^lstm_[a-z]*/"), "Reshape", "ConcatV2"]): "BasicLSTM", |
| 329 | + ): "BasicLSTMReshapeOut", |
| 330 | + repr( |
| 331 | + [re.compile("^lstm/"), "Reshape", "ConcatV2", "Identity"] |
| 332 | + ): "BasicLSTMReshapeOut", |
| 333 | + repr( |
| 334 | + ["Reshape", re.compile("^lstm_[a-z]*/"), "Reshape", "ConcatV2"] |
| 335 | + ): "BasicLSTMReshapeOut", |
| 336 | + repr(["Reshape", re.compile("^lstm_[a-z]*/"), "ConcatV2"]): "BasicLSTMConcatOut", |
330 | 337 | repr(["Sigmoid", "Mul"]): "Swish", |
331 | 338 | repr(["Mul", "Abs", "Mul", "Add"]): "LeakyRelu", |
332 | 339 | repr( |
@@ -546,9 +553,12 @@ def order_by(args, names): |
546 | 553 | "SquaredDifference": lambda nodes, inputs, tensors, _: sqr_diff( |
547 | 554 | nodes[-1].name, inputs[0], inputs[1] |
548 | 555 | ), |
549 | | - "BasicLSTM": lambda nodes, inputs, tensors, context: basic_lstm( |
| 556 | + "BasicLSTMReshapeOut": lambda nodes, inputs, tensors, context: basic_lstm( |
550 | 557 | nodes, inputs, tensors, context, find_type="Reshape" |
551 | 558 | ), |
| 559 | + "BasicLSTMConcatOut": lambda nodes, inputs, tensors, context: basic_lstm( |
| 560 | + nodes, inputs, tensors, context, find_type="ConcatV2" |
| 561 | + ), |
552 | 562 | "Swish": lambda nodes, inputs, tensors, _: Struct(op="Swish", input=inputs), |
553 | 563 | "LeakyRelu": lambda nodes, inputs, tensors, _: Struct(op="LeakyRelu", input=inputs), |
554 | 564 | # TODO:'Round' |
|
0 commit comments