Skip to content

Commit 8d48558

Browse files
JetstreamRoySprowlNucs
authored andcommitted
Fix NDArray.mgrid and add unit tests, Corrected the Object comparer for NDArray to compare shape and data contents. (#314)
1 parent 39eead1 commit 8d48558

File tree

5 files changed

+165
-27
lines changed

5 files changed

+165
-27
lines changed
Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using System;
1+
using System;
22
using NumSharp;
33

44
namespace NumSharp
@@ -8,7 +8,7 @@ public partial class NDArray
88
public (NDArray,NDArray) mgrid(NDArray nd2)
99
{
1010
if( !(this.ndim == 1 || nd2.ndim == 1))
11-
throw new IncorrectShapeException();
11+
throw new IncorrectShapeException("mgrid is implemented only for two single dimension arrays");
1212

1313
Array nd1Data = this.Storage.GetData();
1414
Array nd2Data = nd2.Storage.GetData();
@@ -22,21 +22,17 @@ public partial class NDArray
2222
Array res2Arr = res2.Storage.GetData();
2323

2424
int counter = 0;
25-
26-
for (int idx = 0; idx < nd2Data.Length; idx++)
25+
for (int row = 0; row < nd1Data.Length; row++)
2726
{
28-
for (int jdx = 0; jdx < nd1Data.Length; jdx++)
27+
for (int col = 0; col < nd2Data.Length; col++)
2928
{
30-
res1Arr.SetValue(nd1Data.GetValue(idx),counter);
31-
res2Arr.SetValue(nd2Data.GetValue(idx),counter);
29+
res1Arr.SetValue(nd1Data.GetValue(row), counter);
30+
res2Arr.SetValue(nd2Data.GetValue(col),counter);
3231
counter++;
3332
}
3433
}
35-
3634

3735
return (res1,res2);
3836
}
39-
4037
}
41-
42-
}
38+
}
Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1-
using System;
1+
using System;
22

33
namespace NumSharp
44
{
55
class IncorrectShapeException : System.Exception
66
{
77
public IncorrectShapeException() : base("This method does not work with this shape or was not already implemented.")
88
{
9-
9+
10+
}
11+
12+
public IncorrectShapeException(string msg) : base(msg)
13+
{
14+
1015
}
11-
}
12-
}
16+
}
17+
}

src/NumSharp.Core/Operations/Elementwise/NDArray.Equals.cs

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,55 @@ public override bool Equals(object obj)
2020
{
2121
case NDArray safeCastObj:
2222
{
23+
// We aren't using array_equal here because it assumes that both arrays have
24+
// non-null Storage data and shapes.
2325
var thatData = safeCastObj.Storage?.GetData();
24-
if (thatData == null)
26+
var thisData = this.Storage?.GetData();
27+
if ((thatData == null && thisData != null) ||(thatData != null && thisData == null))
2528
{
2629
return false;
2730
}
31+
if (thisData != null)
32+
{
33+
if (thatData != null)
34+
{
35+
// Compare array contents, which is clumsy since we don't know the element type
36+
if (thisData.Length == thatData.Length)
37+
{
38+
for(int i = 0; i < thisData.Length; i++)
39+
{
40+
if (!thisData.GetValue(i).Equals(thatData.GetValue(i)))
41+
{
42+
return false;
43+
}
44+
}
45+
}
46+
else
47+
{
48+
return false;
49+
}
50+
}
51+
else
52+
{
53+
return false;
54+
}
55+
}
56+
else
57+
{
58+
if (thatData != null)
59+
{
60+
return false;
61+
}
62+
}
2863

29-
var thisData = this.Storage?.GetData();
30-
return thisData == thatData && safeCastObj.shape == this.shape;
64+
if (this.shape != null && safeCastObj.shape != null)
65+
{
66+
return this.shape.SequenceEqual(safeCastObj.shape);
67+
}
68+
else
69+
{
70+
return this.shape == safeCastObj.shape;
71+
}
3172
}
3273
case int val:
3374
return Data<int>(0) == val;
Lines changed: 71 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,89 @@
1-
using Microsoft.VisualStudio.TestTools.UnitTesting;
1+
using Microsoft.VisualStudio.TestTools.UnitTesting;
22
using System;
33
using System.Collections.Generic;
44
using System.Text;
55
using NumSharp.Extensions;
66
using System.Linq;
77
using NumSharp;
8+
using System.Numerics;
89

910
namespace NumSharp.UnitTest.Creation
1011
{
1112
[TestClass]
1213
public class NdArrayMGridTest
1314
{
15+
// These C# NDArray declarations were generated using ndarray-generatory.py,
16+
// which is located in the README.md of this NumSharp.UnitTest project
17+
// using the following Python code:
18+
/*
19+
aa, bb = np.mgrid[0:5, 0:3]
20+
cc, dd = np.mgrid[0:3, 0:5]
21+
ee, ff = np.mgrid[0:5, 0:5]
22+
cSharp.asCode2D("a53", aa)
23+
cSharp.asCode2D("b53", bb)
24+
cSharp.asCode2D("a35", cc)
25+
cSharp.asCode2D("b35", dd)
26+
cSharp.asCode2D("a55", ee)
27+
cSharp.asCode2D("b55", ff)
28+
*/
29+
30+
static NDArray a53 = new NDArray(new Int32[] {
31+
0, 0, 0,
32+
1, 1, 1,
33+
2, 2, 2,
34+
3, 3, 3,
35+
4, 4, 4
36+
}, new Shape(new int[] { 5, 3 }));
37+
38+
static NDArray b53 = new NDArray(new Int32[] {
39+
0, 1, 2,
40+
0, 1, 2,
41+
0, 1, 2,
42+
0, 1, 2,
43+
0, 1, 2
44+
}, new Shape(new int[] { 5, 3 }));
45+
46+
static NDArray a35 = new NDArray(new Int32[] {
47+
0, 0, 0, 0, 0,
48+
1, 1, 1, 1, 1,
49+
2, 2, 2, 2, 2
50+
}, new Shape(new int[] { 3, 5 }));
51+
52+
static NDArray b35 = new NDArray(new Int32[] {
53+
0, 1, 2, 3, 4,
54+
0, 1, 2, 3, 4,
55+
0, 1, 2, 3, 4
56+
}, new Shape(new int[] { 3, 5 }));
57+
58+
static NDArray a55 = new NDArray(new Int32[] {
59+
0, 0, 0, 0, 0,
60+
1, 1, 1, 1, 1,
61+
2, 2, 2, 2, 2,
62+
3, 3, 3, 3, 3,
63+
4, 4, 4, 4, 4
64+
}, new Shape(new int[] { 5, 5 }));
65+
66+
static NDArray b55 = new NDArray(new Int32[] {
67+
0, 1, 2, 3, 4,
68+
0, 1, 2, 3, 4,
69+
0, 1, 2, 3, 4,
70+
0, 1, 2, 3, 4,
71+
0, 1, 2, 3, 4
72+
}, new Shape(new int[] { 5, 5 }));
73+
1474
[TestMethod]
1575
public void BaseTest()
1676
{
17-
var X = np.arange(1, 11, 2).mgrid(np.arange(-12, -3, 3));
18-
19-
NDArray x = X.Item1;
20-
NDArray y = X.Item2;
77+
var V53 = np.arange(0, 5, 1).mgrid(np.arange(0, 3, 1));
78+
var V35 = np.arange(0, 3, 1).mgrid(np.arange(0, 5, 1));
79+
var V55 = np.arange(0, 5, 1).mgrid(np.arange(0, 5, 1));
2180

22-
81+
Assert.AreEqual(V53.Item1, a53);
82+
Assert.AreEqual(V53.Item2, b53);
83+
Assert.AreEqual(V35.Item1, a35);
84+
Assert.AreEqual(V35.Item2, b35);
85+
Assert.AreEqual(V55.Item1, a55);
86+
Assert.AreEqual(V55.Item2, b55);
2387
}
24-
2588
}
26-
27-
}
89+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import math
2+
import numpy as np
3+
4+
# The asCode2D function generates NDArray declarations in C# for use in unit testing.
5+
# This avoids some of the tedium and errors of hand-generation.
6+
# For example, calling the function like this generates C# static variables named
7+
# 'a53' and 'b53' from numpy's mgrid:
8+
# aa, bb = np.mgrid[0:5, 0:3]
9+
# cSharp.asCode2D("a53", aa)
10+
# cSharp.asCode2D("b53", bb)
11+
12+
13+
class cSharp:
14+
def asCode2D(varName, v):
15+
if v.dtype.name == "int32":
16+
vType = "Int32"
17+
elif v.dtype.name == "float64":
18+
vType = "double"
19+
else:
20+
vType = "unknown"
21+
print(" static NDArray {0} = new NDArray(new {1}[] {{".format(varName, vType))
22+
valstr = ""
23+
commasToPrint = v.shape[0] * v.shape[1] - 1
24+
for i, row in enumerate(v):
25+
rowStr = " "
26+
for j, item in enumerate(row):
27+
rowStr = rowStr + "{}".format(item)
28+
if commasToPrint > 0:
29+
rowStr = rowStr + ", "
30+
commasToPrint -= 1
31+
#if (i < v)
32+
print(rowStr)
33+
print(" }}, new Shape(new int[] {{ {}, {} }}));".format(v.shape[0], v.shape[1]))
34+
print("")

0 commit comments

Comments
 (0)