@@ -33,8 +33,7 @@ public static NDArray load(Stream stream)
3333 if ( ! parseReader ( reader , out bytes , out type , out shape ) )
3434 throw new FormatException ( ) ;
3535
36- Array array = Arrays . Create ( type , shape . Aggregate ( ( dims , dim ) => dims * dim ) ) ;
37-
36+ Array array = Arrays . Create ( type , shape . Aggregate ( 1 , ( dims , dim ) => dims * dim ) ) ;
3837 var result = new NDArray ( readValueMatrix ( reader , array , bytes , type , shape ) ) ;
3938 return result . reshape ( shape ) ;
4039 }
@@ -165,6 +164,10 @@ public static Array LoadMatrix(Stream stream)
165164 int [ ] shape ;
166165 if ( ! parseReader ( reader , out bytes , out type , out shape ) )
167166 throw new FormatException ( ) ;
167+
168+ // Read scalar as a single element array
169+ if ( shape . Length == 0 )
170+ shape = new int [ ] { 1 } ;
168171
169172 Array matrix = Arrays . Create ( type , shape ) ;
170173
@@ -188,6 +191,10 @@ public static Array LoadJagged(Stream stream, bool trim = true)
188191 int [ ] shape ;
189192 if ( ! parseReader ( reader , out bytes , out type , out shape ) )
190193 throw new FormatException ( ) ;
194+
195+ // Read scalar as a single element array
196+ if ( shape . Length == 0 )
197+ shape = new int [ ] { 1 } ;
191198
192199 Array matrix = Arrays . Create ( type , shape ) ;
193200
@@ -357,7 +364,7 @@ private static bool parseReader(BinaryReader reader, out int bytes, out Type t,
357364
358365 mark = "'shape': (" ;
359366 s = header . IndexOf ( mark ) + mark . Length ;
360- e = header . IndexOf ( ")" , s + 1 ) ;
367+ e = header . IndexOf ( ")" , s ) ;
361368 shape = header . Substring ( s , e - s ) . Split ( ',' ) . Where ( v => ! String . IsNullOrEmpty ( v ) ) . Select ( Int32 . Parse ) . ToArray ( ) ;
362369
363370 return true ;
0 commit comments