Skip to content

Commit 3a9264b

Browse files
committed
code refactor according to review comments
1 parent 62fcc8c commit 3a9264b

File tree

4 files changed

+6
-34
lines changed

4 files changed

+6
-34
lines changed

tests/run_pretrained_models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323
import tensorflow as tf
2424
from tensorflow.core.framework import graph_pb2
2525
from tensorflow.python.framework.graph_util import convert_variables_to_constants
26-
from tensorflow.contrib.rnn import GRUBlockCell # pylint: disable=unused-import
26+
# contrib ops are registered only when the module is imported, the following import statement is needed,
27+
# otherwise tf runtime error will show up when the tf model is restored from pb file because of un-registered ops.
28+
import tensorflow.contrib.rnn # pylint: disable=unused-import
2729
import yaml
2830
import PIL.Image
2931

tests/test_gru.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,6 @@ def test_dynamic_bigru(self):
312312
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
313313
initializer = init_ops.constant_initializer(0.5)
314314

315-
gru_list = []
316315
if True:
317316
# bigru, no scope
318317
cell1 = rnn.GRUCell(
@@ -326,7 +325,6 @@ def test_dynamic_bigru(self):
326325
cell2,
327326
x,
328327
dtype=tf.float32)
329-
gru_list.append(outputs)
330328

331329
_ = tf.identity(outputs, name="output")
332330
_ = tf.identity(cell_state, name="cell_state")
@@ -345,7 +343,6 @@ def test_dynamic_bigru_output_consumed_only(self):
345343
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
346344
initializer = init_ops.constant_initializer(0.5)
347345

348-
gru_list = []
349346
if True:
350347
# bigru, no scope
351348
cell1 = rnn.GRUCell(
@@ -359,7 +356,6 @@ def test_dynamic_bigru_output_consumed_only(self):
359356
cell2,
360357
x,
361358
dtype=tf.float32)
362-
gru_list.append(outputs)
363359

364360
_ = tf.identity(outputs, name="output")
365361

@@ -377,7 +373,6 @@ def test_dynamic_bigru_state_consumed_only(self):
377373
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
378374
initializer = init_ops.constant_initializer(0.5)
379375

380-
gru_list = []
381376
if True:
382377
# bigru, no scope
383378
cell1 = rnn.GRUCell(
@@ -386,12 +381,11 @@ def test_dynamic_bigru_state_consumed_only(self):
386381
cell2 = rnn.GRUCell(
387382
units,
388383
kernel_initializer=initializer)
389-
outputs, cell_state = tf.nn.bidirectional_dynamic_rnn(
384+
_, cell_state = tf.nn.bidirectional_dynamic_rnn(
390385
cell1,
391386
cell2,
392387
x,
393388
dtype=tf.float32)
394-
gru_list.append(outputs)
395389

396390
_ = tf.identity(cell_state, name="cell_state")
397391

@@ -409,7 +403,6 @@ def test_dynamic_bidirectional_but_one_gru(self):
409403
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
410404
initializer = init_ops.constant_initializer(0.5)
411405

412-
gru_list = []
413406
if True:
414407
# bigru, no scope
415408
cell = rnn.GRUCell(
@@ -420,7 +413,6 @@ def test_dynamic_bidirectional_but_one_gru(self):
420413
cell,
421414
x,
422415
dtype=tf.float32)
423-
gru_list.append(outputs)
424416

425417
_ = tf.identity(outputs, name="output")
426418
_ = tf.identity(cell_state, name="cell_state")
@@ -438,7 +430,6 @@ def test_dynamic_bidirectional_but_one_gru_and_output_consumed_only(self):
438430

439431
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
440432

441-
gru_list = []
442433
if True:
443434
# bigru, no scope
444435
cell = rnn.GRUCell(
@@ -448,7 +439,6 @@ def test_dynamic_bidirectional_but_one_gru_and_output_consumed_only(self):
448439
cell,
449440
x,
450441
dtype=tf.float32)
451-
gru_list.append(outputs)
452442

453443
_ = tf.identity(outputs, name="output")
454444

@@ -465,17 +455,15 @@ def test_dynamic_bidirectional_but_one_gru_and_state_consumed_only(self):
465455

466456
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
467457

468-
gru_list = []
469458
if True:
470459
# bigru, no scope
471460
cell = rnn.GRUCell(
472461
units)
473-
outputs, cell_state = tf.nn.bidirectional_dynamic_rnn(
462+
_, cell_state = tf.nn.bidirectional_dynamic_rnn(
474463
cell,
475464
cell,
476465
x,
477466
dtype=tf.float32)
478-
gru_list.append(outputs)
479467

480468
_ = tf.identity(cell_state, name="cell_state")
481469

tests/test_grublock.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,6 @@ def test_dynamic_bigru(self):
290290

291291
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
292292

293-
gru_list = []
294293
if True:
295294
# bigru, no scope
296295
cell1 = rnn.GRUBlockCell(
@@ -302,7 +301,6 @@ def test_dynamic_bigru(self):
302301
cell2,
303302
x,
304303
dtype=tf.float32)
305-
gru_list.append(outputs)
306304

307305
_ = tf.identity(outputs, name="output")
308306
_ = tf.identity(cell_state, name="cell_state")
@@ -320,7 +318,6 @@ def test_dynamic_bigru_output_consumed_only(self):
320318

321319
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
322320

323-
gru_list = []
324321
if True:
325322
# bigru, no scope
326323
cell1 = rnn.GRUBlockCell(
@@ -332,7 +329,6 @@ def test_dynamic_bigru_output_consumed_only(self):
332329
cell2,
333330
x,
334331
dtype=tf.float32)
335-
gru_list.append(outputs)
336332

337333
_ = tf.identity(outputs, name="output")
338334

@@ -349,7 +345,6 @@ def test_dynamic_bigru_state_consumed_only(self):
349345

350346
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
351347

352-
gru_list = []
353348
if True:
354349
# bigru, no scope
355350
cell1 = rnn.GRUBlockCell(
@@ -361,7 +356,6 @@ def test_dynamic_bigru_state_consumed_only(self):
361356
cell2,
362357
x,
363358
dtype=tf.float32)
364-
gru_list.append(cell_state)
365359

366360
_ = tf.identity(cell_state, name="cell_state")
367361

@@ -378,7 +372,6 @@ def test_dynamic_bidirectional_but_one_gru(self):
378372

379373
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
380374

381-
gru_list = []
382375
if True:
383376
# bigru, no scope
384377
cell = rnn.GRUBlockCell(
@@ -388,7 +381,6 @@ def test_dynamic_bidirectional_but_one_gru(self):
388381
cell,
389382
x,
390383
dtype=tf.float32)
391-
gru_list.append(outputs)
392384

393385
_ = tf.identity(outputs, name="output")
394386
_ = tf.identity(cell_state, name="cell_state")
@@ -406,7 +398,6 @@ def test_dynamic_bidirectional_but_one_gru_and_output_consumed_only(self):
406398

407399
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
408400

409-
gru_list = []
410401
if True:
411402
# bigru, no scope
412403
cell = rnn.GRUBlockCell(
@@ -416,7 +407,6 @@ def test_dynamic_bidirectional_but_one_gru_and_output_consumed_only(self):
416407
cell,
417408
x,
418409
dtype=tf.float32)
419-
gru_list.append(outputs)
420410

421411
_ = tf.identity(outputs, name="output")
422412

@@ -433,7 +423,6 @@ def test_dynamic_bidirectional_but_one_gru_and_state_consumed_only(self):
433423

434424
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
435425

436-
gru_list = []
437426
if True:
438427
# bigru, no scope
439428
cell = rnn.GRUBlockCell(
@@ -443,7 +432,6 @@ def test_dynamic_bidirectional_but_one_gru_and_state_consumed_only(self):
443432
cell,
444433
x,
445434
dtype=tf.float32)
446-
gru_list.append(cell_state)
447435

448436
_ = tf.identity(cell_state, name="cell_state")
449437

tests/test_lstm.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,6 @@ def internal_test_dynamic_bilstm_with_parameters(self, state_is_tuple):
396396
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
397397
initializer = init_ops.constant_initializer(0.5)
398398

399-
lstm_list = []
400399
if True:
401400
# bilstm, no scope
402401
cell1 = rnn.LSTMCell(
@@ -412,7 +411,6 @@ def internal_test_dynamic_bilstm_with_parameters(self, state_is_tuple):
412411
cell2,
413412
x,
414413
dtype=tf.float32)
415-
lstm_list.append(outputs)
416414

417415
_ = tf.identity(outputs, name="output")
418416
_ = tf.identity(cell_state, name="cell_state")
@@ -431,7 +429,6 @@ def test_dynamic_bilstm_output_consumed_only(self, state_is_tuple=True):
431429
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
432430
initializer = init_ops.constant_initializer(0.5)
433431

434-
lstm_list = []
435432
if True:
436433
# bilstm, no scope
437434
cell1 = rnn.LSTMCell(
@@ -447,7 +444,6 @@ def test_dynamic_bilstm_output_consumed_only(self, state_is_tuple=True):
447444
cell2,
448445
x,
449446
dtype=tf.float32)
450-
lstm_list.append(outputs)
451447

452448
_ = tf.identity(outputs, name="output")
453449

@@ -465,7 +461,6 @@ def test_dynamic_bilstm_state_consumed_only(self, state_is_tuple=True):
465461
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
466462
initializer = init_ops.constant_initializer(0.5)
467463

468-
lstm_list = []
469464
if True:
470465
# bilstm, no scope
471466
cell1 = rnn.LSTMCell(
@@ -476,12 +471,11 @@ def test_dynamic_bilstm_state_consumed_only(self, state_is_tuple=True):
476471
units,
477472
initializer=initializer,
478473
state_is_tuple=state_is_tuple)
479-
outputs, cell_state = tf.nn.bidirectional_dynamic_rnn(
474+
_, cell_state = tf.nn.bidirectional_dynamic_rnn(
480475
cell1,
481476
cell2,
482477
x,
483478
dtype=tf.float32)
484-
lstm_list.append(outputs)
485479

486480
_ = tf.identity(cell_state, name="cell_state")
487481

0 commit comments

Comments
 (0)