@@ -75,8 +75,8 @@ private static void ParseGraph(byte[] graphBytes, OnnxGraphInfo graphInfo)
7575 break ;
7676 case 5 : // initializer (weights)
7777 var initBytes = reader . ReadBytes ( ) ;
78- var ( name , weights ) = ParseTensor ( initBytes . ToByteArray ( ) ) ;
79- graphInfo . Initializers [ name ] = weights ;
78+ var tensor = ParseTensor ( initBytes . ToByteArray ( ) ) ;
79+ graphInfo . Initializers [ tensor . Name ] = tensor ;
8080 break ;
8181 case 11 : // input
8282 var inputBytes = reader . ReadBytes ( ) ;
@@ -128,13 +128,15 @@ private static OnnxNode ParseNode(byte[] nodeBytes)
128128 return node ;
129129 }
130130
131- private static ( string name , float [ ] weights ) ParseTensor ( byte [ ] tensorBytes )
131+ private static OnnxTensor ParseTensor ( byte [ ] tensorBytes )
132132 {
133133 using var stream = new MemoryStream ( tensorBytes ) ;
134134 using var reader = new CodedInputStream ( stream ) ;
135135
136136 string name = string . Empty ;
137137 float [ ] weights = Array . Empty < float > ( ) ;
138+ int dataType = - 1 ; // ONNX TensorProto.DataType: 1 = FLOAT, 11 = DOUBLE, etc.
139+ var dims = new List < long > ( ) ;
138140
139141 while ( ! reader . IsAtEnd )
140142 {
@@ -146,17 +148,44 @@ private static (string name, float[] weights) ParseTensor(byte[] tensorBytes)
146148 case 3 : // name
147149 name = reader . ReadString ( ) ;
148150 break ;
151+ case 5 : // data_type
152+ dataType = reader . ReadInt32 ( ) ;
153+ break ;
154+ case 7 : // dims (repeated)
155+ dims . Add ( reader . ReadInt64 ( ) ) ;
156+ break ;
149157 case 9 : // raw_data
150158 var rawBytes = reader . ReadBytes ( ) . ToByteArray ( ) ;
151- weights = BytesToFloatArray ( rawBytes ) ;
159+ // Validate data type before conversion
160+ if ( dataType == 1 ) // FLOAT (32-bit)
161+ {
162+ weights = BytesToFloatArray ( rawBytes ) ;
163+ }
164+ else if ( dataType == - 1 )
165+ {
166+ // data_type field not yet encountered, assume float for backward compatibility
167+ weights = BytesToFloatArray ( rawBytes ) ;
168+ }
169+ else
170+ {
171+ throw new NotSupportedException (
172+ $ "Tensor '{ name } ' has unsupported data type { dataType } . " +
173+ $ "Only FLOAT (type 1) tensors are supported for ONNX→CoreML conversion. " +
174+ $ "Common types: 1=FLOAT, 11=DOUBLE, 2=UINT8, 3=INT8, 6=INT32, 7=INT64.") ;
175+ }
152176 break ;
153177 default :
154178 reader . SkipLastField ( ) ;
155179 break ;
156180 }
157181 }
158182
159- return ( name , weights ) ;
183+ return new OnnxTensor
184+ {
185+ Name = name ,
186+ Data = weights ,
187+ Shape = dims . Select ( d => ( int ) d ) . ToArray ( )
188+ } ;
160189 }
161190
162191 private static OnnxValueInfo ParseValueInfo ( byte [ ] valueInfoBytes )
@@ -191,7 +220,7 @@ private static OnnxValueInfo ParseValueInfo(byte[] valueInfoBytes)
191220
192221 private static int [ ] ParseTypeProto ( byte [ ] typeBytes )
193222 {
194- // Simplified: extract shape dimensions from tensor type
223+ // Parse ONNX TypeProto structure: TypeProto → tensor_type → shape → repeated dim → dim_value
195224 var shape = new List < int > ( ) ;
196225
197226 using var stream = new MemoryStream ( typeBytes ) ;
@@ -200,10 +229,12 @@ private static int[] ParseTypeProto(byte[] typeBytes)
200229 while ( ! reader . IsAtEnd )
201230 {
202231 var tag = reader . ReadTag ( ) ;
203- if ( WireFormat . GetTagWireType ( tag ) == WireFormat . WireType . Varint )
232+ var fieldNumber = WireFormat . GetTagFieldNumber ( tag ) ;
233+
234+ if ( fieldNumber == 1 ) // tensor_type (LengthDelimited)
204235 {
205- var dim = ( int ) reader . ReadInt64 ( ) ;
206- if ( dim > 0 ) shape . Add ( dim ) ;
236+ var tensorTypeBytes = reader . ReadBytes ( ) . ToByteArray ( ) ;
237+ shape = ParseTensorTypeProto ( tensorTypeBytes ) ;
207238 }
208239 else
209240 {
@@ -214,6 +245,88 @@ private static int[] ParseTypeProto(byte[] typeBytes)
214245 return shape . ToArray ( ) ;
215246 }
216247
/// <summary>
/// Walks a serialized ONNX tensor-type message and returns the dimensions of its shape field.
/// </summary>
/// <param name="tensorTypeBytes">Raw bytes of the tensor-type message.</param>
/// <returns>
/// The list produced by <see cref="ParseTensorShapeProto"/> for field 2 (shape).
/// An empty list is returned when the message contains no field 2; if the field
/// occurs more than once, the result of the last occurrence is returned.
/// </returns>
private static List<int> ParseTensorTypeProto(byte[] tensorTypeBytes)
{
    // Field 1 (elem_type) and any other field are skipped; only the shape is of interest here.
    const int ShapeFieldNumber = 2;

    var result = new List<int>();

    using var input = new MemoryStream(tensorTypeBytes);
    using var decoder = new CodedInputStream(input);

    while (!decoder.IsAtEnd)
    {
        var fieldTag = decoder.ReadTag();

        if (WireFormat.GetTagFieldNumber(fieldTag) != ShapeFieldNumber)
        {
            decoder.SkipLastField();
            continue;
        }

        // The shape is an embedded message: read its payload and delegate the decoding.
        var shapePayload = decoder.ReadBytes().ToByteArray();
        result = ParseTensorShapeProto(shapePayload);
    }

    return result;
}
274+
/// <summary>
/// Walks a serialized ONNX shape message and collects the size of each of its dimension entries.
/// </summary>
/// <param name="shapeBytes">Raw bytes of the shape message.</param>
/// <returns>
/// The positive sizes reported by <see cref="ParseDimensionProto"/>, in the order the
/// entries appear. Entries for which that method returns zero or a negative number are
/// omitted, so the list can be shorter than the number of dimension entries in the message.
/// </returns>
private static List<int> ParseTensorShapeProto(byte[] shapeBytes)
{
    // Field 1 is the repeated "dim" entry; every other field is skipped.
    const int DimFieldNumber = 1;

    var collected = new List<int>();

    using var input = new MemoryStream(shapeBytes);
    using var decoder = new CodedInputStream(input);

    while (!decoder.IsAtEnd)
    {
        var fieldTag = decoder.ReadTag();

        if (WireFormat.GetTagFieldNumber(fieldTag) != DimFieldNumber)
        {
            decoder.SkipLastField();
            continue;
        }

        // Each dimension is itself an embedded message.
        var dimPayload = decoder.ReadBytes().ToByteArray();
        var size = ParseDimensionProto(dimPayload);

        if (size <= 0)
        {
            // No usable numeric size for this entry; leave it out of the result.
            continue;
        }

        collected.Add(size);
    }

    return collected;
}
305+
/// <summary>
/// Extracts the numeric size from a serialized ONNX dimension message.
/// </summary>
/// <param name="dimBytes">Raw bytes of a single dimension message.</param>
/// <returns>
/// The first varint-encoded field 1 (dim_value) found in the message, or 0 when the
/// message carries no such field (for example when only dim_param is present).
/// </returns>
/// <exception cref="NotSupportedException">
/// The stored 64-bit value does not fit in an <see cref="int"/>.
/// </exception>
private static int ParseDimensionProto(byte[] dimBytes)
{
    using var stream = new MemoryStream(dimBytes);
    using var reader = new CodedInputStream(stream);

    while (!reader.IsAtEnd)
    {
        var tag = reader.ReadTag();

        // Only decode field 1 when it really is a varint; a field 1 with any other
        // wire type is skipped like an unknown field instead of being misread.
        var isDimValue = WireFormat.GetTagFieldNumber(tag) == 1
            && WireFormat.GetTagWireType(tag) == WireFormat.WireType.Varint;

        if (!isDimValue)
        {
            reader.SkipLastField(); // Skip dim_param and unknown fields
            continue;
        }

        long value = reader.ReadInt64();

        // A plain (int) cast would silently wrap out-of-range values into an
        // unrelated size; fail loudly instead.
        if (value < int.MinValue || value > int.MaxValue)
        {
            throw new NotSupportedException(
                $"ONNX dimension value {value} does not fit in a 32-bit integer.");
        }

        return (int)value;
    }

    return 0;
}
329+
217330 private static float [ ] BytesToFloatArray ( byte [ ] bytes )
218331 {
219332 var floats = new float [ bytes . Length / 4 ] ;
@@ -285,7 +398,7 @@ private static CoreMLNeuralNetwork ConvertNeuralNetwork(OnnxGraphInfo onnxGraph,
285398 return network ;
286399 }
287400
288- private static CoreMLLayer ? ConvertOperatorToLayer ( OnnxNode op , Dictionary < string , float [ ] > initializers , int layerIndex )
401+ private static CoreMLLayer ? ConvertOperatorToLayer ( OnnxNode op , Dictionary < string , OnnxTensor > initializers , int layerIndex )
289402 {
290403 var layer = new CoreMLLayer
291404 {
@@ -303,18 +416,31 @@ private static CoreMLNeuralNetwork ConvertNeuralNetwork(OnnxGraphInfo onnxGraph,
303416
304417 // Extract weights from initializers
305418 var weightsKey = op . Inputs . Count > 1 ? op . Inputs [ 1 ] : null ;
306- if ( weightsKey != null && initializers . TryGetValue ( weightsKey , out var weights ) )
419+ if ( weightsKey != null && initializers . TryGetValue ( weightsKey , out var weightsTensor ) )
307420 {
308- layer . Weights = weights ;
309- layer . InputSize = weights . Length / ( weights . Length > 0 ? ( int ) Math . Sqrt ( weights . Length ) : 1 ) ;
310- layer . OutputSize = ( int ) Math . Sqrt ( weights . Length ) ;
421+ layer . Weights = weightsTensor . Data ;
422+
423+ // Use actual tensor shape instead of sqrt approximation
424+ // ONNX weight matrices for MatMul/Gemm are typically [out_dim, in_dim]
425+ if ( weightsTensor . Shape != null && weightsTensor . Shape . Length == 2 )
426+ {
427+ layer . OutputSize = weightsTensor . Shape [ 0 ] ;
428+ layer . InputSize = weightsTensor . Shape [ 1 ] ;
429+ }
430+ else if ( weightsTensor . Data . Length > 0 )
431+ {
432+ // Fallback for 1D or missing shape: infer square matrix (legacy behavior)
433+ var sqrtLen = ( int ) Math . Sqrt ( weightsTensor . Data . Length ) ;
434+ layer . InputSize = sqrtLen ;
435+ layer . OutputSize = sqrtLen ;
436+ }
311437 }
312438
313439 // Extract bias if present
314440 var biasKey = op . Inputs . Count > 2 ? op . Inputs [ 2 ] : null ;
315- if ( biasKey != null && initializers . TryGetValue ( biasKey , out var bias ) )
441+ if ( biasKey != null && initializers . TryGetValue ( biasKey , out var biasTensor ) )
316442 {
317- layer . Bias = bias ;
443+ layer . Bias = biasTensor . Data ;
318444 layer . HasBias = true ;
319445 }
320446 break ;
@@ -349,7 +475,7 @@ internal class OnnxGraphInfo
349475{
350476 public string Name { get ; set ; } = string . Empty ;
351477 public List < OnnxNode > Operations { get ; set ; } = new ( ) ;
352- public Dictionary < string , float [ ] > Initializers { get ; set ; } = new ( ) ;
478+ public Dictionary < string , OnnxTensor > Initializers { get ; set ; } = new ( ) ;
353479 public List < OnnxValueInfo > Inputs { get ; set ; } = new ( ) ;
354480 public List < OnnxValueInfo > Outputs { get ; set ; } = new ( ) ;
355481}
@@ -373,3 +499,13 @@ internal class OnnxValueInfo
373499 public string Name { get ; set ; } = string . Empty ;
374500 public int [ ] Shape { get ; set ; } = Array . Empty < int > ( ) ;
375501}
502+
/// <summary>
/// Holds one ONNX tensor: its name, its values as 32-bit floats, and its shape.
/// Every property starts out non-null and empty.
/// </summary>
internal class OnnxTensor
{
    /// <summary>Name of the tensor; the empty string until assigned.</summary>
    public string Name { get; set; } = string.Empty;

    /// <summary>Tensor values as a flat array; an empty array until assigned.</summary>
    public float[] Data { get; set; } = Array.Empty<float>();

    /// <summary>Dimension sizes of the tensor; an empty array until assigned.</summary>
    public int[] Shape { get; set; } = Array.Empty<int>();
}
0 commit comments