Skip to content

Commit a5deb6a

Browse files
committed
fix: add production-ready onnx parsing with type validation and correct shape extraction
This commit fixes three critical issues in ONNX→CoreML conversion: 1. **Data type validation in ParseTensor**: Now reads and validates the data_type field (field 5), ensuring only FLOAT tensors are converted. Throws NotSupportedException for unsupported types (DOUBLE, INT8, etc.) instead of silently corrupting data. 2. **Correct TypeProto parsing**: Fixed ParseTypeProto to properly handle nested ONNX protobuf structure (TypeProto → tensor_type → shape → dim → dim_value) instead of incorrectly treating every varint as a dimension. This fixes tensor shape extraction for model inputs/outputs. 3. **Accurate InnerProduct layer sizing**: Changed from Math.Sqrt approximation (which assumed square matrices) to using actual tensor shape from ONNX dims. For MatMul/Gemm layers, correctly extracts [out_dim, in_dim] from weight tensor shape. Technical changes: - ParseTensor now returns OnnxTensor with Name, Data, and Shape fields - Added OnnxTensor class to store tensor metadata alongside float data - Updated OnnxGraphInfo.Initializers from Dictionary<string, float[]> to Dictionary<string, OnnxTensor> - Added ParseTensorTypeProto, ParseTensorShapeProto, and ParseDimensionProto helper methods - ConvertOperatorToLayer uses shape[0] and shape[1] for layer sizing with sqrt fallback
1 parent a756248 commit a5deb6a

File tree

1 file changed

+153
-17
lines changed

1 file changed

+153
-17
lines changed

src/Deployment/Mobile/CoreML/OnnxToCoreMLConverter.cs

Lines changed: 153 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ private static void ParseGraph(byte[] graphBytes, OnnxGraphInfo graphInfo)
7575
break;
7676
case 5: // initializer (weights)
7777
var initBytes = reader.ReadBytes();
78-
var (name, weights) = ParseTensor(initBytes.ToByteArray());
79-
graphInfo.Initializers[name] = weights;
78+
var tensor = ParseTensor(initBytes.ToByteArray());
79+
graphInfo.Initializers[tensor.Name] = tensor;
8080
break;
8181
case 11: // input
8282
var inputBytes = reader.ReadBytes();
@@ -128,13 +128,15 @@ private static OnnxNode ParseNode(byte[] nodeBytes)
128128
return node;
129129
}
130130

131-
private static (string name, float[] weights) ParseTensor(byte[] tensorBytes)
131+
private static OnnxTensor ParseTensor(byte[] tensorBytes)
132132
{
133133
using var stream = new MemoryStream(tensorBytes);
134134
using var reader = new CodedInputStream(stream);
135135

136136
string name = string.Empty;
137137
float[] weights = Array.Empty<float>();
138+
int dataType = -1; // ONNX TensorProto.DataType: 1 = FLOAT, 11 = DOUBLE, etc.
139+
var dims = new List<long>();
138140

139141
while (!reader.IsAtEnd)
140142
{
@@ -146,17 +148,44 @@ private static (string name, float[] weights) ParseTensor(byte[] tensorBytes)
146148
case 3: // name
147149
name = reader.ReadString();
148150
break;
151+
case 5: // data_type
152+
dataType = reader.ReadInt32();
153+
break;
154+
case 7: // dims (repeated)
155+
dims.Add(reader.ReadInt64());
156+
break;
149157
case 9: // raw_data
150158
var rawBytes = reader.ReadBytes().ToByteArray();
151-
weights = BytesToFloatArray(rawBytes);
159+
// Validate data type before conversion
160+
if (dataType == 1) // FLOAT (32-bit)
161+
{
162+
weights = BytesToFloatArray(rawBytes);
163+
}
164+
else if (dataType == -1)
165+
{
166+
// data_type field not yet encountered, assume float for backward compatibility
167+
weights = BytesToFloatArray(rawBytes);
168+
}
169+
else
170+
{
171+
throw new NotSupportedException(
172+
$"Tensor '{name}' has unsupported data type {dataType}. " +
173+
$"Only FLOAT (type 1) tensors are supported for ONNX→CoreML conversion. " +
174+
$"Common types: 1=FLOAT, 11=DOUBLE, 2=UINT8, 3=INT8, 6=INT32, 7=INT64.");
175+
}
152176
break;
153177
default:
154178
reader.SkipLastField();
155179
break;
156180
}
157181
}
158182

159-
return (name, weights);
183+
return new OnnxTensor
184+
{
185+
Name = name,
186+
Data = weights,
187+
Shape = dims.Select(d => (int)d).ToArray()
188+
};
160189
}
161190

162191
private static OnnxValueInfo ParseValueInfo(byte[] valueInfoBytes)
@@ -191,7 +220,7 @@ private static OnnxValueInfo ParseValueInfo(byte[] valueInfoBytes)
191220

192221
private static int[] ParseTypeProto(byte[] typeBytes)
193222
{
194-
// Simplified: extract shape dimensions from tensor type
223+
// Parse ONNX TypeProto structure: TypeProto → tensor_type → shape → repeated dim → dim_value
195224
var shape = new List<int>();
196225

197226
using var stream = new MemoryStream(typeBytes);
@@ -200,10 +229,12 @@ private static int[] ParseTypeProto(byte[] typeBytes)
200229
while (!reader.IsAtEnd)
201230
{
202231
var tag = reader.ReadTag();
203-
if (WireFormat.GetTagWireType(tag) == WireFormat.WireType.Varint)
232+
var fieldNumber = WireFormat.GetTagFieldNumber(tag);
233+
234+
if (fieldNumber == 1) // tensor_type (LengthDelimited)
204235
{
205-
var dim = (int)reader.ReadInt64();
206-
if (dim > 0) shape.Add(dim);
236+
var tensorTypeBytes = reader.ReadBytes().ToByteArray();
237+
shape = ParseTensorTypeProto(tensorTypeBytes);
207238
}
208239
else
209240
{
@@ -214,6 +245,88 @@ private static int[] ParseTypeProto(byte[] typeBytes)
214245
return shape.ToArray();
215246
}
216247

248+
private static List<int> ParseTensorTypeProto(byte[] tensorTypeBytes)
249+
{
250+
// Parse TensorTypeProto: field 1 = elem_type (skip), field 2 = shape
251+
var shape = new List<int>();
252+
253+
using var stream = new MemoryStream(tensorTypeBytes);
254+
using var reader = new CodedInputStream(stream);
255+
256+
while (!reader.IsAtEnd)
257+
{
258+
var tag = reader.ReadTag();
259+
var fieldNumber = WireFormat.GetTagFieldNumber(tag);
260+
261+
if (fieldNumber == 2) // shape (LengthDelimited)
262+
{
263+
var shapeBytes = reader.ReadBytes().ToByteArray();
264+
shape = ParseTensorShapeProto(shapeBytes);
265+
}
266+
else
267+
{
268+
reader.SkipLastField(); // Skip elem_type and unknown fields
269+
}
270+
}
271+
272+
return shape;
273+
}
274+
275+
private static List<int> ParseTensorShapeProto(byte[] shapeBytes)
276+
{
277+
// Parse TensorShapeProto: repeated field 1 = dim
278+
var dims = new List<int>();
279+
280+
using var stream = new MemoryStream(shapeBytes);
281+
using var reader = new CodedInputStream(stream);
282+
283+
while (!reader.IsAtEnd)
284+
{
285+
var tag = reader.ReadTag();
286+
var fieldNumber = WireFormat.GetTagFieldNumber(tag);
287+
288+
if (fieldNumber == 1) // dim (LengthDelimited, repeated)
289+
{
290+
var dimBytes = reader.ReadBytes().ToByteArray();
291+
var dimValue = ParseDimensionProto(dimBytes);
292+
if (dimValue > 0)
293+
{
294+
dims.Add(dimValue);
295+
}
296+
}
297+
else
298+
{
299+
reader.SkipLastField();
300+
}
301+
}
302+
303+
return dims;
304+
}
305+
306+
private static int ParseDimensionProto(byte[] dimBytes)
307+
{
308+
// Parse DimensionProto: field 1 = dim_value (Varint)
309+
using var stream = new MemoryStream(dimBytes);
310+
using var reader = new CodedInputStream(stream);
311+
312+
while (!reader.IsAtEnd)
313+
{
314+
var tag = reader.ReadTag();
315+
var fieldNumber = WireFormat.GetTagFieldNumber(tag);
316+
317+
if (fieldNumber == 1) // dim_value
318+
{
319+
return (int)reader.ReadInt64();
320+
}
321+
else
322+
{
323+
reader.SkipLastField(); // Skip dim_param and unknown fields
324+
}
325+
}
326+
327+
return 0;
328+
}
329+
217330
private static float[] BytesToFloatArray(byte[] bytes)
218331
{
219332
var floats = new float[bytes.Length / 4];
@@ -285,7 +398,7 @@ private static CoreMLNeuralNetwork ConvertNeuralNetwork(OnnxGraphInfo onnxGraph,
285398
return network;
286399
}
287400

288-
private static CoreMLLayer? ConvertOperatorToLayer(OnnxNode op, Dictionary<string, float[]> initializers, int layerIndex)
401+
private static CoreMLLayer? ConvertOperatorToLayer(OnnxNode op, Dictionary<string, OnnxTensor> initializers, int layerIndex)
289402
{
290403
var layer = new CoreMLLayer
291404
{
@@ -303,18 +416,31 @@ private static CoreMLNeuralNetwork ConvertNeuralNetwork(OnnxGraphInfo onnxGraph,
303416

304417
// Extract weights from initializers
305418
var weightsKey = op.Inputs.Count > 1 ? op.Inputs[1] : null;
306-
if (weightsKey != null && initializers.TryGetValue(weightsKey, out var weights))
419+
if (weightsKey != null && initializers.TryGetValue(weightsKey, out var weightsTensor))
307420
{
308-
layer.Weights = weights;
309-
layer.InputSize = weights.Length / (weights.Length > 0 ? (int)Math.Sqrt(weights.Length) : 1);
310-
layer.OutputSize = (int)Math.Sqrt(weights.Length);
421+
layer.Weights = weightsTensor.Data;
422+
423+
// Use actual tensor shape instead of sqrt approximation
424+
// ONNX weight matrices for MatMul/Gemm are typically [out_dim, in_dim]
425+
if (weightsTensor.Shape != null && weightsTensor.Shape.Length == 2)
426+
{
427+
layer.OutputSize = weightsTensor.Shape[0];
428+
layer.InputSize = weightsTensor.Shape[1];
429+
}
430+
else if (weightsTensor.Data.Length > 0)
431+
{
432+
// Fallback for 1D or missing shape: infer square matrix (legacy behavior)
433+
var sqrtLen = (int)Math.Sqrt(weightsTensor.Data.Length);
434+
layer.InputSize = sqrtLen;
435+
layer.OutputSize = sqrtLen;
436+
}
311437
}
312438

313439
// Extract bias if present
314440
var biasKey = op.Inputs.Count > 2 ? op.Inputs[2] : null;
315-
if (biasKey != null && initializers.TryGetValue(biasKey, out var bias))
441+
if (biasKey != null && initializers.TryGetValue(biasKey, out var biasTensor))
316442
{
317-
layer.Bias = bias;
443+
layer.Bias = biasTensor.Data;
318444
layer.HasBias = true;
319445
}
320446
break;
@@ -349,7 +475,7 @@ internal class OnnxGraphInfo
349475
{
350476
public string Name { get; set; } = string.Empty;
351477
public List<OnnxNode> Operations { get; set; } = new();
352-
public Dictionary<string, float[]> Initializers { get; set; } = new();
478+
public Dictionary<string, OnnxTensor> Initializers { get; set; } = new();
353479
public List<OnnxValueInfo> Inputs { get; set; } = new();
354480
public List<OnnxValueInfo> Outputs { get; set; } = new();
355481
}
@@ -373,3 +499,13 @@ internal class OnnxValueInfo
373499
public string Name { get; set; } = string.Empty;
374500
public int[] Shape { get; set; } = Array.Empty<int>();
375501
}
502+
503+
/// <summary>
504+
/// ONNX tensor with data and shape information.
505+
/// </summary>
506+
internal class OnnxTensor
507+
{
508+
public string Name { get; set; } = string.Empty;
509+
public float[] Data { get; set; } = Array.Empty<float>();
510+
public int[] Shape { get; set; } = Array.Empty<int>();
511+
}

0 commit comments

Comments
 (0)