@@ -133,18 +133,14 @@ public static Tensor<T> ConcatenateOnDimension<T>(int dimension, params scoped R
133133 ThrowHelper . ThrowArgument_ConcatenateTooFewTensors ( ) ;
134134
135135 if ( dimension < - 1 || dimension > tensors [ 0 ] . Rank )
136- ThrowHelper . ThrowArgument_InvalidAxis ( ) ;
136+ ThrowHelper . ThrowArgument_InvalidDimension ( ) ;
137137
138- // Calculate total space needed.
139- nint totalLength = 0 ;
140- for ( int i = 0 ; i < tensors . Length ; i ++ )
141- totalLength += tensors [ i ] . FlattenedLength ;
138+ Tensor < T > tensor ;
142139
143- nint sumOfAxis = 0 ;
144140 // If axis != -1, make sure all dimensions except the one to concatenate on match.
145141 if ( dimension != - 1 )
146142 {
147- sumOfAxis = tensors [ 0 ] . Lengths [ dimension ] ;
143+ nint sumOfAxis = tensors [ 0 ] . Lengths [ dimension ] ;
148144 for ( int i = 1 ; i < tensors . Length ; i ++ )
149145 {
150146 if ( tensors [ 0 ] . Rank != tensors [ i ] . Rank )
@@ -157,22 +153,31 @@ public static Tensor<T> ConcatenateOnDimension<T>(int dimension, params scoped R
157153 ThrowHelper . ThrowArgument_InvalidConcatenateShape ( ) ;
158154 }
159155 }
160- sumOfAxis += tensors [ i ] . Lengths [ dimension ] ;
156+ checked
157+ {
158+ sumOfAxis += tensors [ i ] . Lengths [ dimension ] ;
159+ }
161160 }
162- }
163161
164- Tensor < T > tensor ;
165- if ( dimension == - 1 )
166- {
167- tensor = Tensor . Create < T > ( [ totalLength ] ) ;
168- }
169- else
170- {
171162 nint [ ] lengths = new nint [ tensors [ 0 ] . Rank ] ;
172163 tensors [ 0 ] . Lengths . CopyTo ( lengths ) ;
173164 lengths [ dimension ] = sumOfAxis ;
174165 tensor = Tensor . Create < T > ( lengths ) ;
175166 }
167+ else
168+ {
169+ // Calculate total space needed.
170+ nint totalLength = 0 ;
171+ for ( int i = 0 ; i < tensors . Length ; i ++ )
172+ {
173+ checked
174+ {
175+ totalLength += tensors [ i ] . FlattenedLength ;
176+ }
177+ }
178+
179+ tensor = Tensor . Create < T > ( [ totalLength ] ) ;
180+ }
176181
177182 ConcatenateOnDimension ( dimension , tensors , tensor ) ;
178183 return tensor ;
@@ -201,7 +206,7 @@ public static ref readonly TensorSpan<T> ConcatenateOnDimension<T>(int dimension
201206 ThrowHelper . ThrowArgument_ConcatenateTooFewTensors ( ) ;
202207
203208 if ( dimension < - 1 || dimension > tensors [ 0 ] . Rank )
204- ThrowHelper . ThrowArgument_InvalidAxis ( ) ;
209+ ThrowHelper . ThrowArgument_InvalidDimension ( ) ;
205210
206211 // Calculate total space needed.
207212 nint totalLength = 0 ;
@@ -212,11 +217,12 @@ public static ref readonly TensorSpan<T> ConcatenateOnDimension<T>(int dimension
212217 if ( dimension != - 1 )
213218 {
214219 nint sumOfAxis = tensors [ 0 ] . Lengths [ dimension ] ;
220+ int rank = tensors [ 0 ] . Rank ;
215221 for ( int i = 1 ; i < tensors . Length ; i ++ )
216222 {
217- if ( tensors [ 0 ] . Rank != tensors [ i ] . Rank )
223+ if ( rank != tensors [ i ] . Rank )
218224 ThrowHelper . ThrowArgument_InvalidConcatenateShape ( ) ;
219- for ( int j = 0 ; j < tensors [ 0 ] . Rank ; j ++ )
225+ for ( int j = 0 ; j < rank ; j ++ )
220226 {
221227 if ( j != dimension )
222228 {
@@ -228,7 +234,7 @@ public static ref readonly TensorSpan<T> ConcatenateOnDimension<T>(int dimension
228234 }
229235
230236 // Make sure the destination tensor has the correct shape.
231- nint [ ] lengths = new nint [ tensors [ 0 ] . Rank ] ;
237+ nint [ ] lengths = new nint [ rank ] ;
232238 tensors [ 0 ] . Lengths . CopyTo ( lengths ) ;
233239 lengths [ dimension ] = sumOfAxis ;
234240
@@ -339,18 +345,17 @@ public static Tensor<T> Create<T>(T[] array, int start, scoped ReadOnlySpan<nint
339345 /// <returns>A new tensor that contains elements copied from <paramref name="enumerable" />.</returns>
340346 public static Tensor < T > Create < T > ( IEnumerable < T > enumerable , bool pinned = false )
341347 {
348+ T [ ] array = enumerable . ToArray ( ) ;
349+
342350 if ( pinned )
343351 {
344- T [ ] array = enumerable . ToArray ( ) ;
345-
346352 Tensor < T > tensor = CreateUninitialized < T > ( [ array . Length ] , pinned ) ;
347353 array . CopyTo ( tensor . _values ) ;
348354
349355 return tensor ;
350356 }
351357 else
352358 {
353- T [ ] array = enumerable . ToArray ( ) ;
354359 return Create ( array ) ;
355360 }
356361 }
@@ -364,18 +369,17 @@ public static Tensor<T> Create<T>(IEnumerable<T> enumerable, scoped ReadOnlySpan
364369 /// <returns>A new tensor that contains elements copied from <paramref name="enumerable" /> and with the specified <paramref name="lengths" /> and <paramref name="strides" />.</returns>
365370 public static Tensor < T > Create < T > ( IEnumerable < T > enumerable , scoped ReadOnlySpan < nint > lengths , scoped ReadOnlySpan < nint > strides , bool pinned = false )
366371 {
372+ T [ ] array = enumerable . ToArray ( ) ;
373+
367374 if ( pinned )
368375 {
369- T [ ] array = enumerable . ToArray ( ) ;
370-
371376 Tensor < T > tensor = CreateUninitialized < T > ( lengths , strides , pinned ) ;
372377 array . CopyTo ( tensor . _values ) ;
373378
374379 return tensor ;
375380 }
376381 else
377382 {
378- T [ ] array = enumerable . ToArray ( ) ;
379383 return Create ( array , lengths , strides ) ;
380384 }
381385 }
@@ -620,20 +624,8 @@ public static bool EqualsAny<T>(in ReadOnlyTensorSpan<T> x, T y)
620624 /// <param name="value">Value to update in the <paramref name="tensor"/>.</param>
621625 public static ref readonly TensorSpan < T > FilteredUpdate < T > ( in this TensorSpan < T > tensor , scoped in ReadOnlyTensorSpan < bool > filter , T value )
622626 {
623- if ( filter . Lengths . Length != tensor . Lengths . Length )
624- ThrowHelper . ThrowArgument_DimensionsNotSame ( nameof ( filter ) ) ;
625-
626- Span < T > srcSpan = MemoryMarshal . CreateSpan ( ref tensor . _reference , ( int ) tensor . _shape . LinearLength ) ;
627- Span < bool > filterSpan = MemoryMarshal . CreateSpan ( ref filter . _reference , ( int ) tensor . _shape . LinearLength ) ;
628-
629- for ( int i = 0 ; i < filterSpan . Length ; i ++ )
630- {
631- if ( filterSpan [ i ] )
632- {
633- srcSpan [ i ] = value ;
634- }
635- }
636-
627+ TensorOperation . ValidateCompatibility ( filter , tensor ) ;
628+ TensorOperation . Invoke < TensorOperation . FilteredUpdate < T > , bool , T , T > ( filter , value , tensor ) ;
637629 return ref tensor ;
638630 }
639631
@@ -646,24 +638,8 @@ public static ref readonly TensorSpan<T> FilteredUpdate<T>(in this TensorSpan<T>
646638 /// <param name="values">Values to update in the <paramref name="tensor"/>.</param>
647639 public static ref readonly TensorSpan < T > FilteredUpdate < T > ( in this TensorSpan < T > tensor , scoped in ReadOnlyTensorSpan < bool > filter , scoped in ReadOnlyTensorSpan < T > values )
648640 {
649- if ( filter . Lengths . Length != tensor . Lengths . Length )
650- ThrowHelper . ThrowArgument_DimensionsNotSame ( nameof ( filter ) ) ;
651- if ( values . Rank != 1 )
652- ThrowHelper . ThrowArgument_1DTensorRequired ( nameof ( values ) ) ;
653-
654- Span < T > dstSpan = MemoryMarshal . CreateSpan ( ref tensor . _reference , ( int ) tensor . _shape . LinearLength ) ;
655- Span < bool > filterSpan = MemoryMarshal . CreateSpan ( ref filter . _reference , ( int ) tensor . _shape . LinearLength ) ;
656- Span < T > valuesSpan = MemoryMarshal . CreateSpan ( ref values . _reference , ( int ) values . _shape . LinearLength ) ;
657-
658- int index = 0 ;
659- for ( int i = 0 ; i < filterSpan . Length ; i ++ )
660- {
661- if ( filterSpan [ i ] )
662- {
663- dstSpan [ i ] = valuesSpan [ index ++ ] ;
664- }
665- }
666-
641+ TensorOperation . ValidateCompatibility ( filter , values , tensor ) ;
642+ TensorOperation . Invoke < TensorOperation . FilteredUpdate < T > , bool , T , T > ( filter , values , tensor ) ;
667643 return ref tensor ;
668644 }
669645 #endregion
@@ -1409,6 +1385,9 @@ public static Tensor<T> PermuteDimensions<T>(this Tensor<T> tensor, ReadOnlySpan
14091385 }
14101386 else
14111387 {
1388+ if ( ! dimensions . IsEmpty && dimensions . Length != tensor . Lengths . Length )
1389+ ThrowHelper . ThrowArgument_PermuteAxisOrder ( ) ;
1390+
14121391 scoped Span < nint > newLengths = TensorOperation . RentedBuffer . CreateUninitialized ( tensor . Rank , out TensorOperation . RentedBuffer < nint > lengthsRentedBuffer ) ;
14131392 scoped Span < nint > newStrides = TensorOperation . RentedBuffer . CreateUninitialized ( tensor . Rank , out TensorOperation . RentedBuffer < nint > stridesRentedBuffer ) ;
14141393 scoped Span < int > newLinearOrder = TensorOperation . RentedBuffer . CreateUninitialized ( tensor . Rank , out TensorOperation . RentedBuffer < int > linearOrderRentedBuffer ) ;
@@ -1426,11 +1405,12 @@ public static Tensor<T> PermuteDimensions<T>(this Tensor<T> tensor, ReadOnlySpan
14261405 }
14271406 else
14281407 {
1429- if ( dimensions . Length != tensor . Lengths . Length )
1430- ThrowHelper . ThrowArgument_PermuteAxisOrder ( ) ;
1431-
14321408 for ( int i = 0 ; i < dimensions . Length ; i ++ )
14331409 {
1410+ if ( dimensions [ i ] >= tensor . Lengths . Length || dimensions [ i ] < 0 )
1411+ {
1412+ ThrowHelper . ThrowArgument_InvalidDimension ( ) ;
1413+ }
14341414 newLengths [ i ] = tensor . Lengths [ dimensions [ i ] ] ;
14351415 newStrides [ i ] = tensor . Strides [ dimensions [ i ] ] ;
14361416 newLinearOrder [ i ] = tensor . _shape . LinearRankOrder [ dimensions [ i ] ] ;
@@ -1467,7 +1447,8 @@ public static Tensor<T> Reshape<T>(this Tensor<T> tensor, ReadOnlySpan<nint> len
14671447
14681448 nint [ ] newLengths = lengths . ToArray ( ) ;
14691449 // Calculate wildcard info.
1470- if ( lengths . Contains ( - 1 ) )
1450+ int wildcardIndex = lengths . IndexOf ( - 1 ) ;
1451+ if ( wildcardIndex >= 0 )
14711452 {
14721453 if ( lengths . Count ( - 1 ) > 1 )
14731454 ThrowHelper . ThrowArgument_OnlyOneWildcard ( ) ;
@@ -1479,7 +1460,7 @@ public static Tensor<T> Reshape<T>(this Tensor<T> tensor, ReadOnlySpan<nint> len
14791460 tempTotal /= lengths [ i ] ;
14801461 }
14811462 }
1482- newLengths [ lengths . IndexOf ( - 1 ) ] = tempTotal ;
1463+ newLengths [ wildcardIndex ] = tempTotal ;
14831464 }
14841465
14851466 nint tempLinear = TensorPrimitives . Product ( newLengths ) ;
@@ -1538,8 +1519,8 @@ public static TensorSpan<T> Reshape<T>(in this TensorSpan<T> tensor, scoped Read
15381519 }
15391520
15401521 nint [ ] newLengths = lengths . ToArray ( ) ;
1541- // Calculate wildcard info.
1542- if ( lengths . Contains ( - 1 ) )
1522+ int wildcardIndex = lengths . IndexOf ( - 1 ) ;
1523+ if ( wildcardIndex >= 0 )
15431524 {
15441525 if ( lengths . Count ( - 1 ) > 1 )
15451526 ThrowHelper . ThrowArgument_OnlyOneWildcard ( ) ;
@@ -1551,7 +1532,7 @@ public static TensorSpan<T> Reshape<T>(in this TensorSpan<T> tensor, scoped Read
15511532 tempTotal /= lengths [ i ] ;
15521533 }
15531534 }
1554- newLengths [ lengths . IndexOf ( - 1 ) ] = tempTotal ;
1535+ newLengths [ wildcardIndex ] = tempTotal ;
15551536
15561537 }
15571538
@@ -1615,7 +1596,8 @@ public static ReadOnlyTensorSpan<T> Reshape<T>(in this ReadOnlyTensorSpan<T> ten
16151596
16161597 nint [ ] newLengths = lengths . ToArray ( ) ;
16171598 // Calculate wildcard info.
1618- if ( lengths . Contains ( - 1 ) )
1599+ int wildcardIndex = lengths . IndexOf ( - 1 ) ;
1600+ if ( wildcardIndex >= 0 )
16191601 {
16201602 if ( lengths . Count ( - 1 ) > 1 )
16211603 ThrowHelper . ThrowArgument_OnlyOneWildcard ( ) ;
@@ -1627,7 +1609,7 @@ public static ReadOnlyTensorSpan<T> Reshape<T>(in this ReadOnlyTensorSpan<T> ten
16271609 tempTotal /= lengths [ i ] ;
16281610 }
16291611 }
1630- newLengths [ lengths . IndexOf ( - 1 ) ] = tempTotal ;
1612+ newLengths [ wildcardIndex ] = tempTotal ;
16311613
16321614 }
16331615
@@ -1701,12 +1683,7 @@ public static Tensor<T> Resize<T>(Tensor<T> tensor, ReadOnlySpan<nint> lengths)
17011683 /// <param name="destination">Destination <see cref="TensorSpan{T}"/> with the desired new shape.</param>
17021684 public static void ResizeTo < T > ( scoped in Tensor < T > tensor , in TensorSpan < T > destination )
17031685 {
1704- ReadOnlySpan < T > span = MemoryMarshal . CreateSpan ( ref Unsafe . Add ( ref tensor . AsTensorSpan ( ) . _reference , tensor . _start ) , ( int ) tensor . _values . Length - tensor . _start ) ;
1705- Span < T > ospan = MemoryMarshal . CreateSpan ( ref destination . _reference , ( int ) destination . _shape . LinearLength ) ;
1706- if ( ospan . Length >= span . Length )
1707- span . CopyTo ( ospan ) ;
1708- else
1709- span . Slice ( 0 , ospan . Length ) . CopyTo ( ospan ) ;
1686+ ResizeTo ( tensor . AsReadOnlyTensorSpan ( ) , destination ) ;
17101687 }
17111688
17121689 /// <summary>
@@ -1717,12 +1694,7 @@ public static void ResizeTo<T>(scoped in Tensor<T> tensor, in TensorSpan<T> dest
17171694 /// <param name="destination">Destination <see cref="TensorSpan{T}"/> with the desired new shape.</param>
17181695 public static void ResizeTo < T > ( scoped in TensorSpan < T > tensor , in TensorSpan < T > destination )
17191696 {
1720- ReadOnlySpan < T > span = MemoryMarshal . CreateSpan ( ref tensor . _reference , ( int ) tensor . _shape . LinearLength ) ;
1721- Span < T > ospan = MemoryMarshal . CreateSpan ( ref destination . _reference , ( int ) destination . _shape . LinearLength ) ;
1722- if ( ospan . Length >= span . Length )
1723- span . CopyTo ( ospan ) ;
1724- else
1725- span . Slice ( 0 , ospan . Length ) . CopyTo ( ospan ) ;
1697+ ResizeTo ( tensor . AsReadOnlyTensorSpan ( ) , destination ) ;
17261698 }
17271699
17281700 /// <summary>
@@ -1890,6 +1862,8 @@ public static ref readonly TensorSpan<T> SetSlice<T>(this in TensorSpan<T> tenso
18901862 /// <param name="dimension">The axis to split on.</param>
18911863 public static Tensor < T > [ ] Split < T > ( scoped in ReadOnlyTensorSpan < T > tensor , int splitCount , nint dimension )
18921864 {
1865+ if ( dimension < 0 || dimension >= tensor . Rank )
1866+ ThrowHelper . ThrowArgument_AxisLargerThanRank ( ) ;
18931867 if ( tensor . Lengths [ ( int ) dimension ] % splitCount != 0 )
18941868 ThrowHelper . ThrowArgument_SplitNotSplitEvenly ( ) ;
18951869
@@ -2221,8 +2195,10 @@ public static Tensor<T> StackAlongDimension<T>(int dimension, params ReadOnlySpa
22212195 ThrowHelper . ThrowArgument_StackShapesNotSame ( ) ;
22222196 }
22232197
2224- if ( dimension < 0 )
2225- dimension = tensors [ 0 ] . Rank - dimension ;
2198+ // We are safe to do dimension > tensors[0].Rank instead of >= because we are adding a new dimension
2199+ // with our call to Unsqueeze.
2200+ if ( dimension < 0 || dimension > tensors [ 0 ] . Rank )
2201+ ThrowHelper . ThrowArgument_AxisLargerThanRank ( ) ;
22262202
22272203 Tensor < T > [ ] outputs = new Tensor < T > [ tensors . Length ] ;
22282204 for ( int i = 0 ; i < tensors . Length ; i ++ )
@@ -2259,8 +2235,10 @@ public static ref readonly TensorSpan<T> StackAlongDimension<T>(scoped ReadOnlyS
22592235 ThrowHelper . ThrowArgument_StackShapesNotSame ( ) ;
22602236 }
22612237
2262- if ( dimension < 0 )
2263- dimension = tensors [ 0 ] . Rank - dimension ;
2238+ // We are safe to do dimension > tensors[0].Rank instead of >= because we are adding a new dimension
2239+ // with our call to Unsqueeze.
2240+ if ( dimension < 0 || dimension > tensors [ 0 ] . Rank )
2241+ ThrowHelper . ThrowArgument_AxisLargerThanRank ( ) ;
22642242
22652243 Tensor < T > [ ] outputs = new Tensor < T > [ tensors . Length ] ;
22662244 for ( int i = 0 ; i < tensors . Length ; i ++ )
0 commit comments