Skip to content

Commit 54a7ea8

Browse files
committed
mnist_savedmodel
1 parent 0fa6911 commit 54a7ea8

File tree

6 files changed

+15
-6
lines changed

6 files changed

+15
-6
lines changed

examples/mnist_savedmodel.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,15 @@ fn main() -> Result<(), Box<dyn Error>> {
3737

3838
// Load the saved model exported by regression_savedmodel.py.
3939
let mut graph = Graph::new();
40-
let session =
41-
SavedModelBundle::load(&SessionOptions::new(), &["serve"], &mut graph, export_dir)?.session;
42-
let op_x = graph.operation_by_name_required("serving_default_sequential_input")?;
43-
let op_predict = graph.operation_by_name_required("StatefulPartitionedCall")?;
40+
let bundle =
41+
SavedModelBundle::load(&SessionOptions::new(), &["serve"], &mut graph, export_dir)?;
42+
let session = &bundle.session;
43+
44+
let signature = bundle.meta_graph_def().get_signature("serving_default")?;
45+
let input_info = signature.get_input("input")?;
46+
let op_x = graph.operation_by_name_required(&input_info.name().name)?;
47+
let output_info = signature.get_output("output")?;
48+
let op_predict = graph.operation_by_name_required(&output_info.name().name)?;
4449

4550
// Train the model (e.g. for fine tuning).
4651
let mut args = SessionRunArgs::new();
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
4.0663176e-06, 1.4199884e-07, 9.556003e-05, 0.00065914105, 2.260991e-07, 4.076631e-06, 2.5459945e-09, 0.99904054, 1.5654963e-05, 0.00018059688
1+
3.112342e-05, 8.721303e-08, 0.0005018024, 0.0003709061, 1.6482764e-08, 1.8595395e-06, 1.3620006e-09, 0.999046, 5.4331244e-06, 4.2815118e-05

examples/mnist_savedmodel/mnist_savedmodel.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,11 @@
3131
model.fit(x_train, y_train, epochs=1)
3232

3333
# convert output type through softmax so that it can be interpreted as probability
34-
probability_model = tf.keras.Sequential([model, tf.keras.layers.Softmax(name="output")])
34+
inputs = tf.keras.Input((28, 28), name="input", dtype=tf.float32)
35+
x = model(inputs)
36+
outputs = tf.keras.layers.Softmax(name="output")(x)
37+
38+
probability_model = tf.keras.Model(inputs=inputs, outputs=outputs)
3539

3640
# dump expected values to compare Rust's outputs
3741
with open("examples/mnist_savedmodel/expected_values.txt", "w") as f:
-478 Bytes
Binary file not shown.
-106 Bytes
Binary file not shown.
-20 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)