153 changes: 143 additions & 10 deletions caffe2/proto/torch.proto
@@ -8,7 +8,7 @@

syntax = "proto2";

package onnx;
package torch;

// Overview
//
@@ -35,10 +35,11 @@ package onnx;
// by sharing our working version of ONNX.
//
// Protobuf compatibility
//
// To simplify framework compatibility, ONNX is defined using the subset of
// protobuf that is compatible with both protobuf v2 and v3. This means that we
// do not use any protobuf features that are only available in one of the two
// versions.
//
// Here are the most notable contortions we have to carry out to work around
// these limitations:
@@ -47,7 +48,6 @@ package onnx;
// of key-value pairs, where order does not matter and duplicates
// are not allowed.


// Versioning
//
// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md
@@ -60,8 +60,8 @@ enum Version {
_START_VERSION = 0;
// The version field is always serialized and we will use it to store the
// version that the graph is generated from. This helps us set up version
// control.
// For the IR, we are using simple numbers starting with 0x00000001,
// which was the version we published on Oct 10, 2017.
IR_VERSION_2017_10_10 = 0x0000000000000001;

@@ -74,7 +74,10 @@ enum Version {
// - Added new message OperatorSetIdProto
// - Added opset_import in ModelProto
// - For vendor extensions, added domain in NodeProto
IR_VERSION = 0x0000000000000003;
IR_VERSION_NEWEST_ONNX = 0x0000000000000003;

// PYTORCH IR VERSION
IR_VERSION_NEWEST = 0x0000000100000003;
}
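
A usage note: a writer stamps one of these constants onto the model and a reader can branch on it. Below is a minimal sketch in Python, assuming this file has been compiled with protoc and is importable as torch_pb2, and that ModelProto keeps the ONNX-style ir_version field (that field line is elided from this diff).

```python
# Minimal sketch only; torch_pb2 is the assumed name of the generated module.
from caffe2.proto import torch_pb2

model = torch_pb2.ModelProto()
# Top-level proto2 enum values are exposed as module-level constants.
model.ir_version = torch_pb2.IR_VERSION_NEWEST  # field name assumed from ONNX

def describe(m):
    # A reader can branch on the stored version.
    if m.ir_version == torch_pb2.IR_VERSION_NEWEST_ONNX:
        return "plain ONNX IR"
    if m.ir_version == torch_pb2.IR_VERSION_NEWEST:
        return "PyTorch/Caffe2 extended IR"
    return "older or unknown IR version"

print(describe(model))
```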

// Attributes
@@ -174,6 +177,12 @@ message NodeProto {

// A human-readable documentation for this node. Markdown is allowed.
optional string doc_string = 6;

Peter is making a change to store the schema (just a string) in the doc_string for now. We could add an additional schema field here for that.

We could use an additional string field as well, to store the schema of a script module.

Actually, we can use op_type to store the schema, and store the node type as one annotation, with name "IR_NODE_TYPE"

// Additional annotations; the attributes are defined in the schema.
repeated AttributeProto annotation = 8;

// Node type, like PythonOp, etc, purely for PyTorch
optional string node_type = 51;
}
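
To make the proposal in the thread above concrete, here is a hedged sketch of a node that carries its schema string in op_type and its kind both in the dedicated node_type field and as an IR_NODE_TYPE annotation. The op_type field and the AttributeProto name/s fields are assumed to follow the ONNX layout that this diff elides; the schema string itself is illustrative.

```python
from caffe2.proto import torch_pb2  # assumed generated module name

node = torch_pb2.NodeProto()
# Proposal from the thread: keep the operator schema string in op_type ...
node.op_type = "add(Tensor self, Tensor other) -> Tensor"  # illustrative schema
# ... keep the node kind in the PyTorch-only field ...
node.node_type = "PythonOp"
# ... and/or record it as an annotation named IR_NODE_TYPE.
ann = node.annotation.add()
ann.name = "IR_NODE_TYPE"  # AttributeProto.name, assumed per the ONNX layout
ann.s = b"PythonOp"        # AttributeProto.s is a bytes field in ONNX
```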

// Models
@@ -182,6 +191,9 @@ message NodeProto {
// associating its computation graph with metadata.
//
// The semantics of the model are described by the associated GraphProto.
//
// Model ==> Caffe2 MetaNetDef
// ==> PyTorch Module
message ModelProto {
// The version of the IR this model targets. See Version enum above.
// This field MUST be present.
@@ -222,10 +234,28 @@ message ModelProto {
optional string doc_string = 6;

// The parameterized graph that is evaluated to execute the model.
// The main graph; in the single-graph case it is ONNX compatible.
optional GraphProto graph = 7;

// The remaining nets in MetaNetDef.
// Submodules and methods in PyTorch.
repeated GraphProto methods = 15;

// Named metadata values; keys should be distinct.
// Much of the metadata from MetaNetDef and the predictor is piggybacked here.
// 1) project
// 2) model_class
// 3) internal_version
// 4) predictor_type
// 5) predictor_id
// 6) execute_plan
// 7) applicationSpecificInfo (another string map; need to verify it has no duplicates)
// 8) engine
// 9) publish time
repeated StringStringEntryProto metadata_props = 14;

// Model name
optional string name = 16;
};
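
For the piggybacked metadata listed above, entries are plain key/value pairs. A minimal sketch, assuming StringStringEntryProto keeps the usual key/value fields (its body is elided from this diff) and that GraphProto has an ONNX-style name field; all values are illustrative.

```python
from caffe2.proto import torch_pb2  # assumed generated module name

model = torch_pb2.ModelProto()
model.name = "example_predictor"  # illustrative model name

# A subset of the MetaNetDef/predictor metadata mentioned above.
for key, value in {
    "project": "some_project",
    "model_class": "some_model_class",
    "internal_version": "1",
    "publish_time": "2018-09-01",
}.items():
    entry = model.metadata_props.add()
    entry.key = key      # assumes StringStringEntryProto { key, value }
    entry.value = value

# Submodules and methods are stored as additional graphs.
forward = model.methods.add()
forward.name = "forward"  # GraphProto.name assumed from the ONNX layout
```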

// StringStringEntryProto follows the pattern for cross-proto-version maps.
@@ -241,6 +271,8 @@ message StringStringEntryProto {
// list of nodes that form a directed acyclic graph based on their inputs and outputs.
// This is the equivalent of the "network" or "graph" in many deep learning
// frameworks.
// Graph ==> NetDef in Caffe2
// ==> Submodule/Method in PyTorch
message GraphProto {
// The nodes in the graph, sorted topologically.
repeated NodeProto node = 1;
@@ -264,6 +296,9 @@ message GraphProto {
// must be distinct. It is optional for a value to appear in value_info list.
repeated ValueInfoProto value_info = 13;

// Additional annotations.
repeated AttributeProto annotation = 14;

// DO NOT USE the following fields, they were deprecated from earlier versions.
// repeated string input = 3;
// repeated string output = 4;
@@ -298,10 +333,15 @@ message TensorProto {
COMPLEX64 = 14; // complex with float32 real and imaginary components
COMPLEX128 = 15; // complex with float64 real and imaginary components
// Future extensions go here.

// Special data type; the real type information is stored in ValueInfoProto.
// If data_type is SPECIAL, raw_data should be used.
SPECIAL = 51;
}

// The shape of the tensor.
repeated int64 dims = 1;
repeated int64 strides = 14;

// The data type of the tensor.
optional DataType data_type = 2;
@@ -313,6 +353,7 @@ message TensorProto {
optional int64 begin = 1;
optional int64 end = 2;
}
// Used as an offset into the external shared data.
optional Segment segment = 3;

// Tensor content must be organized in row-major order.
@@ -383,6 +424,25 @@ message TensorProto {
// When this field is present, the data_type field MUST be
// UINT32 or UINT64
repeated uint64 uint64_data = 11 [packed = true];

There's two additional things that we store for Parameters that don't have a place to go: is_buffer and requires_grad. Both are booleans.

// External data by file name
optional string external_data = 13;

// If two tensors represent the same weights/content, use alias.
// A TensorProto with the name given in alias must exist in the initializer list.
// This avoids duplicating tensor data in attributes, such as the value of a Constant node.
// This is useful if everything is stored directly in the proto.
optional string alias = 16;

// Additional annotations.
repeated AttributeProto annotation = 17;

// Device info
optional DeviceOption device_detail = 51;

// For PyTorch serialized tensor.
optional int64 require_gradient = 52;
optional int64 is_bool = 53;
/is_bool/is_buffer/

oops ><

}
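
Putting the new tensor fields together, here is a sketch of serializing a small dense tensor. numpy is used only to produce row-major bytes; FLOAT is assumed from the elided part of the DataType enum, the stride unit (elements vs. bytes) is an assumption, and the generated module name is assumed as before.

```python
import numpy as np
from caffe2.proto import torch_pb2  # assumed generated module name

arr = np.arange(6, dtype=np.float32).reshape(2, 3)

t = torch_pb2.TensorProto()
t.dims.extend(arr.shape)                      # shape (2, 3)
t.strides.extend([3, 1])                      # contiguous strides, assumed to be in elements
t.data_type = torch_pb2.TensorProto.FLOAT
t.raw_data = arr.tobytes()                    # row-major content, as required above
t.require_gradient = 1                        # PyTorch-specific flag from this diff
t.is_bool = 0                                 # the thread above suggests renaming this to is_buffer
t.device_detail.device_type = torch_pb2.CUDA  # DeviceOption, defined at the end of this file
t.device_detail.cuda_gpu_id = 0

blob = t.SerializeToString()                  # standard protobuf serialization
```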

// Defines a tensor shape. A dimension can be either an integer value
@@ -401,7 +461,10 @@ message TensorShapeProto {
// for pre-defined dimension denotations.
optional string denotation = 3;
};
// To represent a scalar, use a dim list with one element and dim[0].dim_value = 0.
Wouldn't this be the same as a one dimensional tensor with size (0,)?

Good point, let's make sure, if it is scalar, dim is absent.

repeated Dimension dim = 1;

repeated Dimension stride = 51;
}
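
A small sketch of the dim/stride split: a tensor of shape (2, 2) with strides (1, 2) gets two dim entries and two stride entries, rather than the interleaved 2, 2, 1, 2 encoding mentioned in the review thread below. The dim_value field on Dimension is the one referenced in the scalar note above; the generated module name is an assumption.

```python
from caffe2.proto import torch_pb2  # assumed generated module name

shape = torch_pb2.TensorShapeProto()
for size in (2, 2):
    shape.dim.add().dim_value = size     # shape (2, 2)
for s in (1, 2):
    shape.stride.add().dim_value = s     # strides (1, 2), kept in their own field
# Per the scalar discussion above, a true scalar would simply leave dim empty.
```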

// Types
@@ -416,11 +479,39 @@ message TypeProto {
optional TensorShapeProto shape = 2;
We're currently storing strides in Dimension of TensorShapeProto. For example, if a tensor is of shape (2, 2) and strides (1, 2), we store 2, 2, 1, 2 in Dimension. Would be nice to store these separately.

Good point, I will also add stride field into TensorShapeProto.

}

// Sequence type: List, Tuple
message Sequence {
I think it would be better to have separate messages for List and Tuple and get rid of Sequence because we know Lists have one elem_type and Tuple have multiple. Is there a specific reason we need to support undefined sequence type?

Yeah, because we try to make it onnx compatible. In onnx(-ml) it only supports sequence with the same data type. Let's make it backward compatible. :-)

// elem_type and elem_type_list cannot appear together.
// If all the element types are the same, we use elem_type,
// otherwise, we specify the type of each element in elem_type_list.
optional TypeProto elem_type = 1;
repeated TypeProto elem_type_list = 51;
enum SequenceType {
UNDEFINED = 0;
LIST = 1;
TUPLE = 2;
}
optional SequenceType sequence_type = 52;
}

// Map<K, V>, (not necessary at this moment)
message Map {
optional TensorProto.DataType key_type = 1;
optional TypeProto value_type = 2;
}

// Special type of blobs; based on the type_name, we can choose the right
// serializer and deserializer.
message SpecialBlob {
optional string type_name = 1;
}

oneof value {
// The type of a tensor.
Tensor tensor_type = 1;

Sequence sequence_type = 4;
Map map_type = 5;
SpecialBlob special_blob_type = 51;
}
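
Tying the oneof together, here is a hedged sketch of describing a heterogeneous tuple (a float tensor plus an int64 tensor) versus a homogeneous list, following the backward-compatible design discussed in the Sequence thread above. TypeProto.Tensor.elem_type and the FLOAT/INT64 values are assumed from the ONNX layout elided in this diff.

```python
from caffe2.proto import torch_pb2  # assumed generated module name

# Tuple[FloatTensor, LongTensor]: per-element types go in elem_type_list.
tup = torch_pb2.TypeProto()
seq = tup.sequence_type                     # selects the Sequence branch of the oneof
seq.sequence_type = torch_pb2.TypeProto.Sequence.TUPLE
seq.elem_type_list.add().tensor_type.elem_type = torch_pb2.TensorProto.FLOAT
seq.elem_type_list.add().tensor_type.elem_type = torch_pb2.TensorProto.INT64

# List[FloatTensor]: all elements share a single elem_type.
lst = torch_pb2.TypeProto()
lst.sequence_type.sequence_type = torch_pb2.TypeProto.Sequence.LIST
lst.sequence_type.elem_type.tensor_type.elem_type = torch_pb2.TensorProto.FLOAT
```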

// An optional denotation can be used to denote the whole
@@ -444,3 +535,45 @@ message OperatorSetIdProto {
// This field MUST be present in this version of the IR.
optional int64 version = 2;
}

// DeviceType that Caffe2 currently supports.
// Note: if you add a device type, make sure you add the corresponding device
// line in the DeviceTypeName() function in caffe2/utils/proto_utils.cc
// and update ATen/core/DeviceType.h
enum DeviceType {
CPU = 0; // By default, we will use CPU.
CUDA = 1; // CUDA.
MKLDNN = 2; // Reserved for explicit MKLDNN
OPENGL = 3; // OpenGL
OPENCL = 4; // OpenCL
IDEEP = 5; // IDEEP.
HIP = 6; // AMD HIP
// Change the following number if you add more devices in the code.
COMPILE_TIME_MAX_DEVICE_TYPES = 7;
ONLY_FOR_TEST = 20901701; // This device type is only for test.
}

// Device-specific options. We do not distinguish DeviceOption protos for
// different DeviceTypes, so currently all devices share the same DeviceOption
// proto. Fields that are specific to a device type are ignored if the type does
// not match.
// Note: if you add fields to the DeviceOption, make sure you add the
// corresponding changes to IsSameDevice() function in utils/proto_utils.{h,cc}.
message DeviceOption {
// [general] Options that need to be carried out before running the execution.
// optional DeviceType device_type = 1 [ default = CPU ];
optional int32 device_type = 1 [ default = 0 ]; // 0 is CPU.
// [CUDA specific] the cuda gpu id.
optional int32 cuda_gpu_id = 2;
// [general] The random seed to start the device random number generator with.
optional uint32 random_seed = 3;
// [general] What node this op should execute on.
// Used for net transformation purposes. Must be empty at execution time.
optional string node_name = 4;
// [CPU and Linux specific] NUMA node id
optional int32 numa_node_id = 5 [default = -1];
// [general] Extra information passed, not used at execution time currently.
repeated string extra_info = 6;
// [HIP specific] the hip gpu id.
optional int32 hip_gpu_id = 7;
}
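
Finally, a sketch of the device description that device_detail on TensorProto carries; the values are illustrative and the generated module name is assumed as before.

```python
from caffe2.proto import torch_pb2  # assumed generated module name

dev = torch_pb2.DeviceOption()
dev.device_type = torch_pb2.CUDA   # stored as int32; enum constants are plain ints (CUDA == 1)
dev.cuda_gpu_id = 2                # illustrative GPU id
dev.random_seed = 1234
dev.extra_info.append("note=profiling-run")  # free-form; ignored at execution time per the comment above
```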