| 
 | 1 | +use std::error::Error;  | 
 | 2 | +use std::fs::File;  | 
 | 3 | +use std::io::Read;  | 
 | 4 | +use std::path::Path;  | 
 | 5 | +use std::result::Result;  | 
 | 6 | +use tensorflow::Code;  | 
 | 7 | +use tensorflow::Graph;  | 
 | 8 | +use tensorflow::ImportGraphDefOptions;  | 
 | 9 | +use tensorflow::Session;  | 
 | 10 | +use tensorflow::SessionOptions;  | 
 | 11 | +use tensorflow::SessionRunArgs;  | 
 | 12 | +use tensorflow::Status;  | 
 | 13 | +use tensorflow::Tensor;  | 
 | 14 | + | 
 | 15 | +use ndarray;  | 
 | 16 | + | 
 | 17 | +use image::io::Reader as ImageReader;  | 
 | 18 | +use image::GenericImageView;  | 
 | 19 | + | 
 | 20 | +fn main() -> Result<(), Box<dyn Error>> {  | 
 | 21 | +    let filename = "examples/zenn/mobilenetv3large.pb";  | 
 | 22 | +    if !Path::new(filename).exists() {  | 
 | 23 | +        return Err(Box::new(  | 
 | 24 | +            Status::new_set(  | 
 | 25 | +                Code::NotFound,  | 
 | 26 | +                &format!(  | 
 | 27 | +                    "Run 'python examples/zenn/zenn.py' to generate {} \  | 
 | 28 | +                     and try again.",  | 
 | 29 | +                    filename  | 
 | 30 | +                ),  | 
 | 31 | +            )  | 
 | 32 | +            .unwrap(),  | 
 | 33 | +        ));  | 
 | 34 | +    }  | 
 | 35 | + | 
 | 36 | +    // Create input variables for our addition  | 
 | 37 | +    let mut x = Tensor::new(&[1, 224, 224, 3]);  | 
 | 38 | +    let img = ImageReader::open("examples/zenn/sample.png")?.decode()?;  | 
 | 39 | +    for (i, (_, _, pixel)) in img.pixels().enumerate() {  | 
 | 40 | +        x[3 * i] = pixel.0[0] as f32;  | 
 | 41 | +        x[3 * i + 1] = pixel.0[1] as f32;  | 
 | 42 | +        x[3 * i + 2] = pixel.0[2] as f32;  | 
 | 43 | +    }  | 
 | 44 | + | 
 | 45 | +    // Load the computation graph defined by addition.py.  | 
 | 46 | +    let mut graph = Graph::new();  | 
 | 47 | +    let mut proto = Vec::new();  | 
 | 48 | +    File::open(filename)?.read_to_end(&mut proto)?;  | 
 | 49 | +    graph.import_graph_def(&proto, &ImportGraphDefOptions::new())?;  | 
 | 50 | +    let session = Session::new(&SessionOptions::new(), &graph)?;  | 
 | 51 | + | 
 | 52 | +    // Run the graph.  | 
 | 53 | +    let mut args = SessionRunArgs::new();  | 
 | 54 | +    args.add_feed(&graph.operation_by_name_required("x")?, 0, &x);  | 
 | 55 | +    let output = args.request_fetch(&graph.operation_by_name_required("Identity")?, 0);  | 
 | 56 | +    session.run(&mut args)?;  | 
 | 57 | + | 
 | 58 | +    // Check our results.  | 
 | 59 | +    let output: Tensor<f32> = args.fetch(output)?;  | 
 | 60 | +    let res: ndarray::Array<f32, _> = output.into();  | 
 | 61 | +    println!("{:?}", res);  | 
 | 62 | + | 
 | 63 | +    Ok(())  | 
 | 64 | +}  | 
0 commit comments