Skip to content

Commit bad8767

Browse files
committed
NDArray: added GetIndicesFromSlice which is needed for indexing special cases
1 parent 502342c commit bad8767

File tree

2 files changed

+38
-4
lines changed

2 files changed

+38
-4
lines changed

src/NumSharp.Core/Selection/NDArray.Indexing.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,25 @@ private NDArray _extract_indices(NDArray[] mindices, bool isCollapsed, NDArray @
314314
}
315315
}
316316

317+
/// <summary>
318+
/// Converts a slice to indices for the special case where slices are mixed with NDArrays in this[...]
319+
/// </summary>
320+
/// <param name="shape"></param>
321+
/// <param name="slice"></param>
322+
/// <param name="axis"></param>
323+
/// <returns></returns>
324+
private static NDArray<int> GetIndicesFromSlice(Shape shape, Slice slice, int axis)
325+
{
326+
var dim = shape.Dimensions[axis];
327+
var slice_def = slice.ToSliceDef(dim); // this resolves negative slice indices
328+
return np.arange(slice_def.Start, slice_def.Start+slice_def.Step*slice_def.Count, slice.Step).MakeGeneric<int>();
329+
}
330+
331+
/// <summary>
332+
/// Slice the array with Python slice notation like this: ":, 2:7:1, ..., np.newaxis"
333+
/// </summary>
334+
/// <param name="slice">A string containing slice notations for every dimension, delimited by comma</param>
335+
/// <returns>A sliced view</returns>
317336
public NDArray this[string slice]
318337
{
319338
get => new NDArray(Storage.GetView(Slice.ParseSlices(slice)));

test/NumSharp.UnitTest/Selection/NDArray.Indexing.Test.cs

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.Collections.Generic;
66
using System.Diagnostics;
77
using System.Linq;
8+
using System.Reflection;
89
using System.Text.RegularExpressions;
910
using System.Threading;
1011
using FluentAssertions;
@@ -1122,8 +1123,8 @@ public void Masking_2D_over_3D()
11221123
// [20, 21, 22, 23, 24],
11231124
// [25, 26, 27, 28, 29]])
11241125
var x = np.arange(30).reshape(2, 3, 5);
1125-
var b = np.array(new[,] {{true, true, false}, {false, true, true } }).MakeGeneric<bool>();
1126-
y[b[":, 5"]].Should().BeOfValues(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29).And.BeShaped(4,5);
1126+
var b = np.array(new[,] { { true, true, false }, { false, true, true } }).MakeGeneric<bool>();
1127+
y[b[":, 5"]].Should().BeOfValues(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29).And.BeShaped(4, 5);
11271128
}
11281129

11291130
[TestMethod]
@@ -1147,7 +1148,7 @@ public void Combining_IndexArrays_with_Slices()
11471148
// [15, 16],
11481149
// [29, 30]])
11491150
var y = np.arange(35).reshape(5, 7);
1150-
y[np.array(0,2,4), "1:3"].Should().BeOfValues(1, 2, 15, 16, 29, 30).And.BeShaped(3,2);
1151+
y[np.array(0, 2, 4), "1:3"].Should().BeOfValues(1, 2, 15, 16, 29, 30).And.BeShaped(3, 2);
11511152
}
11521153

11531154

@@ -1168,7 +1169,21 @@ public void Combining_MaskArrays_with_Slices()
11681169
// [29, 30]])
11691170
var y = np.arange(35).reshape(5, 7);
11701171
var b = y > 20;
1171-
y[b[":, 5"], "1:3"].Should().BeOfValues(22,23, 29, 30).And.BeShaped(2, 2);
1172+
y[b[":, 5"], "1:3"].Should().BeOfValues(22, 23, 29, 30).And.BeShaped(2, 2);
1173+
}
1174+
1175+
// use this as a proxy for the private static method GetIndicesFromSlice of NDArray
1176+
private NDArray<int> GetIndicesFromSlice(Shape shape, Slice slice, int axis)
1177+
{
1178+
var method = typeof(NDArray).GetMethod("GetIndicesFromSlice", BindingFlags.NonPublic | BindingFlags.Static);
1179+
return (NDArray<int>)method.Invoke(null, new object[] { shape, slice, axis });
1180+
}
1181+
1182+
[TestMethod]
1183+
public void GetIndicesFromSlice_Test()
1184+
{
1185+
GetIndicesFromSlice((3, 4, 3), new Slice("::2"), 1).Should().BeOfValues(0, 2).And.BeShaped(2);
1186+
GetIndicesFromSlice((3, 4, 3), new Slice("-1::-1"), 0).Should().BeOfValues(2,1,0);
11721187
}
11731188
}
11741189
}

0 commit comments

Comments
 (0)