Skip to content

Commit 9bc75ec

Browse files
committed
fix pylint
1 parent 1942e49 commit 9bc75ec

File tree

3 files changed

+21
-30
lines changed

3 files changed

+21
-30
lines changed

tests/make_models.py

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,10 @@
88
from __future__ import unicode_literals
99

1010
import os
11-
import unittest
12-
from collections import namedtuple
13-
14-
import graphviz as gv
15-
from onnx import TensorProto
16-
from onnx import helper
1711

12+
import numpy as np
1813
import tensorflow as tf
1914
from tensorflow.python.framework.graph_util import convert_variables_to_constants
20-
import numpy as np
21-
22-
import os
23-
2415

2516
# pylint: disable=missing-docstring
2617

@@ -29,12 +20,13 @@
2920
training_epochs = 100
3021

3122
# Training Data
32-
train_X = np.array(
23+
_train_x = np.array(
3324
[3.3, 4.4, 5.5, 6.71, 6.93, 4.168, 9.779, 6.182, 7.59, 2.167, 7.042, 10.791, 5.313, 7.997, 5.654, 9.27, 3.1])
34-
train_Y = np.array(
25+
_train_y = np.array(
3526
[1.7, 2.76, 2.09, 3.19, 1.694, 1.573, 3.366, 2.596, 2.53, 1.221, 2.827, 3.465, 1.65, 2.904, 2.42, 2.94, 1.3])
36-
test_X = np.array([6.83, 4.668, 8.9, 7.91, 5.7, 8.7, 3.1, 2.1])
37-
test_Y = np.array([1.84, 2.273, 3.2, 2.831, 2.92, 3.24, 1.35, 1.03])
27+
_test_x = np.array([6.83, 4.668, 8.9, 7.91, 5.7, 8.7, 3.1, 2.1])
28+
_test_y = np.array([1.84, 2.273, 3.2, 2.831, 2.92, 3.24, 1.35, 1.03])
29+
3830

3931
def freeze_session(sess, keep_var_names=None, output_names=None, clear_devices=True):
4032
"""Freezes the state of a session into a pruned computation graph."""
@@ -52,20 +44,21 @@ def freeze_session(sess, keep_var_names=None, output_names=None, clear_devices=T
5244
output_names, freeze_var_names)
5345
return frozen_graph
5446

47+
5548
def train(model_path):
56-
n_samples = train_X.shape[0]
49+
n_samples = _train_x.shape[0]
5750

5851
# tf Graph Input
59-
X = tf.placeholder(tf.float32, name="X")
60-
Y = tf.placeholder(tf.float32, name="Y")
52+
x = tf.placeholder(tf.float32, name="X")
53+
y = tf.placeholder(tf.float32, name="Y")
6154

6255
# Set model weights
63-
W = tf.Variable(np.random.randn(), name="W")
56+
w = tf.Variable(np.random.randn(), name="W")
6457
b = tf.Variable(np.random.randn(), name="b")
6558

66-
pred = tf.add(tf.multiply(X, W), b)
59+
pred = tf.add(tf.multiply(x, w), b)
6760
pred = tf.identity(pred, name="pred")
68-
cost = tf.reduce_sum(tf.pow(pred - Y, 2)) / (2 * n_samples)
61+
cost = tf.reduce_sum(tf.pow(pred - y, 2)) / (2 * n_samples)
6962

7063
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
7164
saver = tf.train.Saver()
@@ -75,11 +68,11 @@ def train(model_path):
7568
sess.run(tf.global_variables_initializer())
7669

7770
# Fit all training data
78-
for epoch in range(training_epochs):
79-
for (x, y) in zip(train_X, train_Y):
80-
sess.run(optimizer, feed_dict={X: x, Y: y})
81-
training_cost = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
82-
testing_cost = sess.run(cost, feed_dict={X: test_X, Y: test_Y})
71+
for _ in range(training_epochs):
72+
for (ix, iy) in zip(_train_x, _train_y):
73+
sess.run(optimizer, feed_dict={x: ix, y: iy})
74+
training_cost = sess.run(cost, feed_dict={x: _train_x, y: _train_y})
75+
testing_cost = sess.run(cost, feed_dict={x: _test_x, y: _test_y})
8376
print("train_cost={}, test_cost={}, diff={}"
8477
.format(training_cost, testing_cost, abs(training_cost - testing_cost)))
8578

@@ -92,8 +85,7 @@ def train(model_path):
9285
tf.train.write_graph(frozen_graph, p, "frozen.pb", as_text=False)
9386

9487
p = os.path.abspath(os.path.join(model_path, "saved_model"))
95-
tf.saved_model.simple_save(sess, p, inputs={"X": X}, outputs={"pred": pred})
88+
tf.saved_model.simple_save(sess, p, inputs={"X": x}, outputs={"pred": pred})
9689

9790

9891
train("models/regression")
99-

tests/run_pretrained_models.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
# otherwise tf runtime error will show up when the tf model is restored from pb file because of un-registered ops.
2626
import tensorflow.contrib.rnn # pylint: disable=unused-import
2727
import yaml
28-
from tensorflow.core.framework import graph_pb2
2928

3029
import tf2onnx
3130
from tf2onnx import loader
@@ -227,7 +226,7 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
227226
graph_def, inputs, outputs = loader.from_saved_model(model_path, inputs, outputs)
228227
else:
229228
graph_def, inputs, outputs = loader.from_graphdef(model_path, inputs, outputs)
230-
229+
231230
# create the input data
232231
inputs = {}
233232
for k, v in self.input_names.items():

tf2onnx/convert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def get_args():
5757
if not args.input and not args.outputs:
5858
raise ValueError("graphdef and checkpoint models need to provide inputs and outputs")
5959
if not any([args.graphdef, args.checkpoint, args.saved_model]):
60-
raise ValueError("need input as graphdef, checkpoint or saved_model")
60+
raise ValueError("need input as graphdef, checkpoint or saved_model")
6161
if args.inputs:
6262
args.inputs, args.shape_override = utils.split_nodename_and_shape(args.inputs)
6363
if args.outputs:

0 commit comments

Comments
 (0)