This document describes the "native" C++ inference API available for Gematria models based on the `TokenGraphBuilderModel` class. This is provided in addition to the Python and command-line APIs available to all Gematria models.
To use the C++ API, you need to add the inference library as a dependency. You will also need one or more `.tflite` files with a trained GRANITE model, either on disk or bundled with the binary in some other way.
The API has a single class `GraphBuilderModelInference`. The general workflow when using the API is:

1.  Load the TensorFlow Lite model as a `FlatBufferModel` object using the factory methods of `tflite::FlatBufferModel`.
2.  Create a `GraphBuilderModelInference` object using the TensorFlow Lite model. The model contains all the necessary parameters to create the inference object.
3.  For each batch, add basic blocks to the batch using `GraphBuilderModelInference::AddBasicBlockToBatch()`, then compute the predictions for the basic blocks in the batch using `GraphBuilderModelInference::RunInference()`.
4.  To process more batches, call `GraphBuilderModelInference::Reset()` and repeat from step 3.
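The workflow above can be sketched in C++ as follows. This is a minimal illustration, not a verbatim sample from the library: the header paths, the factory function used to create the `GraphBuilderModelInference` object, and the construction of `BasicBlock` values are assumptions and may differ from the actual code; only `AddBasicBlockToBatch()`, `RunInference()`, and `Reset()` are taken from the description above.

```cpp
#include <memory>
#include <vector>

#include "gematria/basic_block/basic_block.h"                 // BasicBlock; path assumed.
#include "gematria/granite/graph_builder_model_inference.h"   // Header name assumed.
#include "tensorflow/lite/model_builder.h"                    // tflite::FlatBufferModel.

// Hypothetical helper that runs inference over several batches of basic
// blocks using a single inference object.
void RunBatches(const std::vector<std::vector<gematria::BasicBlock>>& batches) {
  // Step 1: Load the TensorFlow Lite model from a .tflite file.
  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromFile("/tmp/gnn.tflite");

  // Step 2: Create the inference object from the model. The factory name
  // `FromTfLiteModel` is an assumption; check the actual header for the
  // real creation function.
  auto inference =
      gematria::GraphBuilderModelInference::FromTfLiteModel(model.get());

  for (const auto& batch : batches) {
    // Step 3: Add the basic blocks of the batch, then run inference.
    for (const gematria::BasicBlock& block : batch) {
      inference->AddBasicBlockToBatch(block);
    }
    const auto predictions = inference->RunInference();
    // ... consume `predictions` for this batch ...

    // Step 4: Reset the batch state before processing the next batch.
    inference->Reset();
  }
}
```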
The C++ API uses the TensorFlow Lite framework. It wraps the TensorFlow Lite classes and functions and provides a convenient API based on the `BasicBlock` data structures used by the rest of the Gematria code base.
By default, the Gematria inference library links just the necessary TensorFlow Lite libraries. You may have to add additional dependencies to use GPU processing when available.
A `.tflite` file contains a TensorFlow Lite computation graph. Such files are typically created from a trained TensorFlow model, i.e. a TensorFlow graph and variable values.
The conversion of Gematria models to the `.tflite` format is typically done in two steps:
1.  Export a frozen `tensorflow.GraphDef` from a model and a checkpoint with weights. This can be done by passing `--gematria_action=export_graph_def` and `--gematria_checkpoint_file` to the model binary. The command-line flags specifying the hyper-parameters of the model must be set the same way as they were set during the training of the model. With a checkpoint file specified, the export replaces all `Variable` nodes in the TensorFlow graph with constant nodes containing the values of the variables from the checkpoint.

    Example:

    ```shell
    bazel run -c opt \
      //gematria/granite:token_graph_builder_model_main \
      -- \
      --logtostderr \
      --gematria_action=export_graph_def \
      --gematria_edge_embedding_size=256 \
      --gematria_edge_update_layers=256,256 \
      --gematria_global_embedding_size=256 \
      --gematria_global_update_layers=256,256 \
      --gematria_graph_module_layer_normalization \
      --gematria_graph_module_residual_connections \
      --gematria_node_embedding_size=256 \
      --gematria_node_update_layers=256,256 \
      --gematria_num_message_passing_iterations=4 \
      --nogematria_readout_input_layer_normalization \
      --gematria_readout_layers='' \
      --gematria_readout_residual_connections \
      --gematria_seq2seq \
      --gematria_task_names=ivb \
      --gematria_task_names=hsw \
      --gematria_task_names=skx \
      --nogematria_task_readout_input_layer_normalization \
      --gematria_task_readout_layers= \
      --gematria_task_readout_residual_connections \
      --gematria_use_sent_edges \
      --nogematria_use_seq2seq_loss \
      --gematria_checkpoint_file=/tmp/model.ckpt-1 \
      --gematria_graph_def_file=/tmp/gnn_frozen_graph.pbtxt
    ```

2.  Convert the model from the `GraphDef` format to the `.tflite` format. This is done using the `tflite_convert` tool. For convenience, and to ensure that models are exported with their inputs and outputs always in the same order, we provide a conversion script.

    Example:

    ```shell
    gematria/granite/convert_gb_token_model_to_tflite.sh \
      --gematria_input_graphdef /tmp/gnn_frozen_graph.pbtxt \
      --gematria_output_tflite /tmp/gnn.tflite
    ```