Skip to content
This repository was archived by the owner on Dec 21, 2023. It is now read-only.

Commit 5393ebb

Browse files
authored
Drawing Classifer Compatible with TensorFlow V2 Behavior (#3028)
Rather than initialize TensorFlow variables to zero then assign the correct values, variables must be initialized to the correct values at the begining.
1 parent 82208b3 commit 5393ebb

File tree

1 file changed

+27
-90
lines changed

1 file changed

+27
-90
lines changed

src/python/turicreate/toolkits/drawing_classifier/_tf_drawing_classifier.py

Lines changed: 27 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import turicreate.toolkits._tf_utils as _utils
1313
import tensorflow.compat.v1 as _tf
1414

15+
# This toolkit is compatible with TensorFlow V2 behavior.
16+
# However, until all toolkits are compatible, we must call `disable_v2_behavior()`.
1517
_tf.disable_v2_behavior()
1618

1719

@@ -20,7 +22,6 @@ def __init__(self, net_params, batch_size, num_classes):
2022
"""
2123
Defines the TensorFlow model, loss, optimisation and accuracy. Then
2224
loads the weights into the model.
23-
2425
"""
2526
self.gpu_policy = _utils.TensorFlowGPUPolicy()
2627
self.gpu_policy.start()
@@ -54,40 +55,34 @@ def init_drawing_classifier_graph(self, net_params):
5455

5556
# Weights
5657
weights = {
57-
"drawing_conv0_weight": _tf.Variable(
58-
_tf.zeros([3, 3, 1, 16]), name="drawing_conv0_weight"
59-
),
60-
"drawing_conv1_weight": _tf.Variable(
61-
_tf.zeros([3, 3, 16, 32]), name="drawing_conv1_weight"
62-
),
63-
"drawing_conv2_weight": _tf.Variable(
64-
_tf.zeros([3, 3, 32, 64]), name="drawing_conv2_weight"
65-
),
66-
"drawing_dense0_weight": _tf.Variable(
67-
_tf.zeros([576, 128]), name="drawing_dense0_weight"
68-
),
69-
"drawing_dense1_weight": _tf.Variable(
70-
_tf.zeros([128, self.num_classes]), name="drawing_dense1_weight"
71-
),
58+
name: _tf.Variable(_utils.convert_conv2d_coreml_to_tf(net_params[name]), name=name)
59+
for name in ("drawing_conv0_weight",
60+
"drawing_conv1_weight",
61+
"drawing_conv2_weight")
7262
}
63+
weights["drawing_dense1_weight"] = _tf.Variable(
64+
_utils.convert_dense_coreml_to_tf(net_params["drawing_dense1_weight"]), name="drawing_dense1_weight"
65+
)
66+
"""
67+
To make output of CoreML pool3 (NCHW) compatible with TF (NHWC).
68+
Decompose FC weights to NCHW. Transpose to NHWC. Reshape back to FC.
69+
"""
70+
coreml_128_576 = net_params["drawing_dense0_weight"]
71+
coreml_128_576 = _np.reshape(coreml_128_576, (128, 64, 3, 3))
72+
coreml_128_576 = _np.transpose(coreml_128_576, (0, 2, 3, 1))
73+
coreml_128_576 = _np.reshape(coreml_128_576, (128, 576))
74+
weights["drawing_dense0_weight"] = _tf.Variable(
75+
_np.transpose(coreml_128_576, (1, 0)), name="drawing_dense0_weight"
76+
)
7377

7478
# Biases
7579
biases = {
76-
"drawing_conv0_bias": _tf.Variable(
77-
_tf.zeros([16]), name="drawing_conv0_bias"
78-
),
79-
"drawing_conv1_bias": _tf.Variable(
80-
_tf.zeros([32]), name="drawing_conv1_bias"
81-
),
82-
"drawing_conv2_bias": _tf.Variable(
83-
_tf.zeros([64]), name="drawing_conv2_bias"
84-
),
85-
"drawing_dense0_bias": _tf.Variable(
86-
_tf.zeros([128]), name="drawing_dense0_bias"
87-
),
88-
"drawing_dense1_bias": _tf.Variable(
89-
_tf.zeros([self.num_classes]), name="drawing_dense1_bias"
90-
),
80+
name: _tf.Variable(net_params[name], name=name)
81+
for name in ("drawing_conv0_bias",
82+
"drawing_conv1_bias",
83+
"drawing_conv2_bias",
84+
"drawing_dense0_bias",
85+
"drawing_dense1_bias")
9186
}
9287

9388
conv_1 = _tf.nn.conv2d(
@@ -119,23 +114,19 @@ def init_drawing_classifier_graph(self, net_params):
119114

120115
# Flatten the data to a 1-D vector for the fully connected layer
121116
fc1 = _tf.reshape(pool_3, (-1, 576))
122-
123117
fc1 = _tf.nn.xw_plus_b(
124118
fc1,
125119
weights=weights["drawing_dense0_weight"],
126120
biases=biases["drawing_dense0_bias"],
127121
)
128-
129122
fc1 = _tf.nn.relu(fc1)
130123

131124
out = _tf.nn.xw_plus_b(
132125
fc1,
133126
weights=weights["drawing_dense1_weight"],
134127
biases=biases["drawing_dense1_bias"],
135128
)
136-
softmax_out = _tf.nn.softmax(out)
137-
138-
self.predictions = softmax_out
129+
self.predictions = _tf.nn.softmax(out)
139130

140131
# Loss
141132
self.cost = _tf.losses.softmax_cross_entropy(
@@ -153,60 +144,6 @@ def init_drawing_classifier_graph(self, net_params):
153144
self.sess = _tf.Session()
154145
self.sess.run(_tf.global_variables_initializer())
155146

156-
# Assign the initialised weights from C++ to tensorflow
157-
layers = [
158-
"drawing_conv0_weight",
159-
"drawing_conv0_bias",
160-
"drawing_conv1_weight",
161-
"drawing_conv1_bias",
162-
"drawing_conv2_weight",
163-
"drawing_conv2_bias",
164-
"drawing_dense0_weight",
165-
"drawing_dense0_bias",
166-
"drawing_dense1_weight",
167-
"drawing_dense1_bias",
168-
]
169-
170-
for key in layers:
171-
if "bias" in key:
172-
self.sess.run(
173-
_tf.assign(
174-
_tf.get_default_graph().get_tensor_by_name(key + ":0"),
175-
net_params[key],
176-
)
177-
)
178-
else:
179-
if "drawing_dense0_weight" in key:
180-
"""
181-
To make output of CoreML pool3 (NCHW) compatible with TF (NHWC).
182-
Decompose FC weights to NCHW. Transpose to NHWC. Reshape back to FC.
183-
"""
184-
coreml_128_576 = net_params[key]
185-
coreml_128_576 = _np.reshape(coreml_128_576, (128, 64, 3, 3))
186-
coreml_128_576 = _np.transpose(coreml_128_576, (0, 2, 3, 1))
187-
coreml_128_576 = _np.reshape(coreml_128_576, (128, 576))
188-
self.sess.run(
189-
_tf.assign(
190-
_tf.get_default_graph().get_tensor_by_name(key + ":0"),
191-
_np.transpose(coreml_128_576, (1, 0)),
192-
)
193-
)
194-
elif "dense" in key:
195-
dense_weights = _utils.convert_dense_coreml_to_tf(net_params[key])
196-
self.sess.run(
197-
_tf.assign(
198-
_tf.get_default_graph().get_tensor_by_name(key + ":0"),
199-
dense_weights,
200-
)
201-
)
202-
else:
203-
self.sess.run(
204-
_tf.assign(
205-
_tf.get_default_graph().get_tensor_by_name(key + ":0"),
206-
_utils.convert_conv2d_coreml_to_tf(net_params[key]),
207-
)
208-
)
209-
210147
def __del__(self):
211148
self.sess.close()
212149
self.gpu_policy.stop()

0 commit comments

Comments
 (0)