Skip to content

Commit 0e86e8d

Browse files
committed
Make C# enums work as proper enums in Python
Avoid converting C# enums to long in Python
1 parent 68a2183 commit 0e86e8d

File tree

5 files changed

+799
-2
lines changed

5 files changed

+799
-2
lines changed

src/embed_tests/EnumTests.cs

Lines changed: 353 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,353 @@
1+
using System;
2+
using System.Collections.Generic;
3+
4+
using NUnit.Framework;
5+
6+
using Python.Runtime;
7+
8+
namespace Python.EmbeddingTest
9+
{
10+
public class EnumTests
11+
{
12+
[OneTimeSetUp]
13+
public void SetUp()
14+
{
15+
PythonEngine.Initialize();
16+
}
17+
18+
[OneTimeTearDown]
19+
public void Dispose()
20+
{
21+
PythonEngine.Shutdown();
22+
}
23+
24+
public enum Direction
25+
{
26+
Down = -2,
27+
Flat = 0,
28+
Up = 2,
29+
}
30+
31+
[Test]
32+
public void CSharpEnumsBehaveAsEnumsInPython()
33+
{
34+
using var _ = Py.GIL();
35+
var module = PyModule.FromString("CSharpEnumsBehaveAsEnumsInPython", $@"
36+
from clr import AddReference
37+
AddReference(""Python.EmbeddingTest"")
38+
39+
from Python.EmbeddingTest import *
40+
41+
def enum_is_right_type(enum_value=EnumTests.Direction.Up):
42+
return isinstance(enum_value, EnumTests.Direction)
43+
");
44+
45+
Assert.IsTrue(module.InvokeMethod("enum_is_right_type").As<bool>());
46+
47+
// Also test passing the enum value from C# to Python
48+
using var pyEnumValue = Direction.Up.ToPython();
49+
Assert.IsTrue(module.InvokeMethod("enum_is_right_type", pyEnumValue).As<bool>());
50+
}
51+
52+
private PyModule GetTestOperatorsModule(string @operator, Direction operand1, double operand2)
53+
{
54+
var operand1Str = $"{nameof(EnumTests)}.{nameof(Direction)}.{operand1}";
55+
return PyModule.FromString("GetTestOperatorsModule", $@"
56+
from clr import AddReference
57+
AddReference(""Python.EmbeddingTest"")
58+
59+
from Python.EmbeddingTest import *
60+
61+
def operation1():
62+
return {operand1Str} {@operator} {operand2}
63+
64+
def operation2():
65+
return {operand2} {@operator} {operand1Str}
66+
");
67+
}
68+
69+
[TestCase(" *", Direction.Down, 2, -4)]
70+
[TestCase("/", Direction.Down, 2, -1)]
71+
[TestCase("+", Direction.Down, 2, 0)]
72+
[TestCase("-", Direction.Down, 2, -4)]
73+
[TestCase("*", Direction.Flat, 2, 0)]
74+
[TestCase("/", Direction.Flat, 2, 0)]
75+
[TestCase("+", Direction.Flat, 2, 2)]
76+
[TestCase("-", Direction.Flat, 2, -2)]
77+
[TestCase("*", Direction.Up, 2, 4)]
78+
[TestCase("/", Direction.Up, 2, 1)]
79+
[TestCase("+", Direction.Up, 2, 4)]
80+
[TestCase("-", Direction.Up, 2, 0)]
81+
public void ArithmeticOperatorsWorkWithoutExplicitCast(string @operator, Direction operand1, double operand2, double expectedResult)
82+
{
83+
using var _ = Py.GIL();
84+
var module = GetTestOperatorsModule(@operator, operand1, operand2);
85+
86+
Assert.AreEqual(expectedResult, module.InvokeMethod("operation1").As<double>());
87+
Assert.AreEqual(expectedResult, module.InvokeMethod("operation2").As<double>());
88+
}
89+
90+
[TestCase("==", Direction.Down, -2, true)]
91+
[TestCase("==", Direction.Down, 0, false)]
92+
[TestCase("==", Direction.Down, 2, false)]
93+
[TestCase("==", Direction.Flat, -2, false)]
94+
[TestCase("==", Direction.Flat, 0, true)]
95+
[TestCase("==", Direction.Flat, 2, false)]
96+
[TestCase("==", Direction.Up, -2, false)]
97+
[TestCase("==", Direction.Up, 0, false)]
98+
[TestCase("==", Direction.Up, 2, true)]
99+
[TestCase("!=", Direction.Down, -2, false)]
100+
[TestCase("!=", Direction.Down, 0, true)]
101+
[TestCase("!=", Direction.Down, 2, true)]
102+
[TestCase("!=", Direction.Flat, -2, true)]
103+
[TestCase("!=", Direction.Flat, 0, false)]
104+
[TestCase("!=", Direction.Flat, 2, true)]
105+
[TestCase("!=", Direction.Up, -2, true)]
106+
[TestCase("!=", Direction.Up, 0, true)]
107+
[TestCase("!=", Direction.Up, 2, false)]
108+
[TestCase("<", Direction.Down, -3, false)]
109+
[TestCase("<", Direction.Down, -2, false)]
110+
[TestCase("<", Direction.Down, 0, true)]
111+
[TestCase("<", Direction.Down, 2, true)]
112+
[TestCase("<", Direction.Flat, -2, false)]
113+
[TestCase("<", Direction.Flat, 0, false)]
114+
[TestCase("<", Direction.Flat, 2, true)]
115+
[TestCase("<", Direction.Up, -2, false)]
116+
[TestCase("<", Direction.Up, 0, false)]
117+
[TestCase("<", Direction.Up, 2, false)]
118+
[TestCase("<", Direction.Up, 3, true)]
119+
[TestCase("<=", Direction.Down, -3, false)]
120+
[TestCase("<=", Direction.Down, -2, true)]
121+
[TestCase("<=", Direction.Down, 0, true)]
122+
[TestCase("<=", Direction.Down, 2, true)]
123+
[TestCase("<=", Direction.Flat, -2, false)]
124+
[TestCase("<=", Direction.Flat, 0, true)]
125+
[TestCase("<=", Direction.Flat, 2, true)]
126+
[TestCase("<=", Direction.Up, -2, false)]
127+
[TestCase("<=", Direction.Up, 0, false)]
128+
[TestCase("<=", Direction.Up, 2, true)]
129+
[TestCase("<=", Direction.Up, 3, true)]
130+
[TestCase(">", Direction.Down, -3, true)]
131+
[TestCase(">", Direction.Down, -2, false)]
132+
[TestCase(">", Direction.Down, 0, false)]
133+
[TestCase(">", Direction.Down, 2, false)]
134+
[TestCase(">", Direction.Flat, -2, true)]
135+
[TestCase(">", Direction.Flat, 0, false)]
136+
[TestCase(">", Direction.Flat, 2, false)]
137+
[TestCase(">", Direction.Up, -2, true)]
138+
[TestCase(">", Direction.Up, 0, true)]
139+
[TestCase(">", Direction.Up, 2, false)]
140+
[TestCase(">", Direction.Up, 3, false)]
141+
[TestCase(">=", Direction.Down, -3, true)]
142+
[TestCase(">=", Direction.Down, -2, true)]
143+
[TestCase(">=", Direction.Down, 0, false)]
144+
[TestCase(">=", Direction.Down, 2, false)]
145+
[TestCase(">=", Direction.Flat, -2, true)]
146+
[TestCase(">=", Direction.Flat, 0, true)]
147+
[TestCase(">=", Direction.Flat, 2, false)]
148+
[TestCase(">=", Direction.Up, -2, true)]
149+
[TestCase(">=", Direction.Up, 0, true)]
150+
[TestCase(">=", Direction.Up, 2, true)]
151+
[TestCase(">=", Direction.Up, 3, false)]
152+
public void IntComparisonOperatorsWorkWithoutExplicitCast(string @operator, Direction operand1, int operand2, bool expectedResult)
153+
{
154+
using var _ = Py.GIL();
155+
var module = GetTestOperatorsModule(@operator, operand1, operand2);
156+
157+
Assert.AreEqual(expectedResult, module.InvokeMethod("operation1").As<bool>());
158+
159+
var invertedOperationExpectedResult = (@operator.StartsWith('<') || @operator.StartsWith('>')) && Convert.ToInt64(operand1) != operand2
160+
? !expectedResult
161+
: expectedResult;
162+
Assert.AreEqual(invertedOperationExpectedResult, module.InvokeMethod("operation2").As<bool>());
163+
}
164+
165+
[TestCase("==", Direction.Down, -2.0, true)]
166+
[TestCase("==", Direction.Down, -2.00001, false)]
167+
[TestCase("==", Direction.Down, -1.99999, false)]
168+
[TestCase("==", Direction.Down, 0.0, false)]
169+
[TestCase("==", Direction.Down, 2.0, false)]
170+
[TestCase("==", Direction.Flat, -2.0, false)]
171+
[TestCase("==", Direction.Flat, 0.0, true)]
172+
[TestCase("==", Direction.Flat, 0.00001, false)]
173+
[TestCase("==", Direction.Flat, -0.00001, false)]
174+
[TestCase("==", Direction.Flat, 2.0, false)]
175+
[TestCase("==", Direction.Up, -2.0, false)]
176+
[TestCase("==", Direction.Up, 0.0, false)]
177+
[TestCase("==", Direction.Up, 2.0, true)]
178+
[TestCase("==", Direction.Up, 2.00001, false)]
179+
[TestCase("==", Direction.Up, 1.99999, false)]
180+
[TestCase("!=", Direction.Down, -2.0, false)]
181+
[TestCase("!=", Direction.Down, -2.00001, true)]
182+
[TestCase("!=", Direction.Down, -1.99999, true)]
183+
[TestCase("!=", Direction.Down, 0.0, true)]
184+
[TestCase("!=", Direction.Down, 2.0, true)]
185+
[TestCase("!=", Direction.Flat, -2.0, true)]
186+
[TestCase("!=", Direction.Flat, 0.0, false)]
187+
[TestCase("!=", Direction.Flat, 0.00001, true)]
188+
[TestCase("!=", Direction.Flat, -0.00001, true)]
189+
[TestCase("!=", Direction.Flat, 2.0, true)]
190+
[TestCase("!=", Direction.Up, -2.0, true)]
191+
[TestCase("!=", Direction.Up, 0.0, true)]
192+
[TestCase("!=", Direction.Up, 2.0, false)]
193+
[TestCase("!=", Direction.Up, 2.00001, true)]
194+
[TestCase("!=", Direction.Up, 1.99999, true)]
195+
[TestCase("<", Direction.Down, -3.0, false)]
196+
[TestCase("<", Direction.Down, -2.00001, false)]
197+
[TestCase("<", Direction.Down, -2.0, false)]
198+
[TestCase("<", Direction.Down, -1.99999, true)]
199+
[TestCase("<", Direction.Down, 0.0, true)]
200+
[TestCase("<", Direction.Down, 2.0, true)]
201+
[TestCase("<", Direction.Flat, -2.0, false)]
202+
[TestCase("<", Direction.Flat, -0.00001, false)]
203+
[TestCase("<", Direction.Flat, 0.0, false)]
204+
[TestCase("<", Direction.Flat, 0.00001, true)]
205+
[TestCase("<", Direction.Flat, 2.0, true)]
206+
[TestCase("<", Direction.Up, -2.0, false)]
207+
[TestCase("<", Direction.Up, 0.0, false)]
208+
[TestCase("<", Direction.Up, 1.99999, false)]
209+
[TestCase("<", Direction.Up, 2.0, false)]
210+
[TestCase("<", Direction.Up, 2.00001, true)]
211+
[TestCase("<", Direction.Up, 3.0, true)]
212+
[TestCase("<=", Direction.Down, -3.0, false)]
213+
[TestCase("<=", Direction.Down, -2.00001, false)]
214+
[TestCase("<=", Direction.Down, -2.0, true)]
215+
[TestCase("<=", Direction.Down, -1.99999, true)]
216+
[TestCase("<=", Direction.Down, 0.0, true)]
217+
[TestCase("<=", Direction.Down, 2.0, true)]
218+
[TestCase("<=", Direction.Flat, -2.0, false)]
219+
[TestCase("<=", Direction.Flat, -0.00001, false)]
220+
[TestCase("<=", Direction.Flat, 0.0, true)]
221+
[TestCase("<=", Direction.Flat, 0.00001, true)]
222+
[TestCase("<=", Direction.Flat, 2.0, true)]
223+
[TestCase("<=", Direction.Up, -2.0, false)]
224+
[TestCase("<=", Direction.Up, 0.0, false)]
225+
[TestCase("<=", Direction.Up, 1.99999, false)]
226+
[TestCase("<=", Direction.Up, 2.0, true)]
227+
[TestCase("<=", Direction.Up, 2.00001, true)]
228+
[TestCase("<=", Direction.Up, 3.0, true)]
229+
[TestCase(">", Direction.Down, -3.0, true)]
230+
[TestCase(">", Direction.Down, -2.00001, true)]
231+
[TestCase(">", Direction.Down, -2.0, false)]
232+
[TestCase(">", Direction.Down, -1.99999, false)]
233+
[TestCase(">", Direction.Down, 0.0, false)]
234+
[TestCase(">", Direction.Down, 2.0, false)]
235+
[TestCase(">", Direction.Flat, -2.0, true)]
236+
[TestCase(">", Direction.Flat, -0.00001, true)]
237+
[TestCase(">", Direction.Flat, 0.0, false)]
238+
[TestCase(">", Direction.Flat, 0.00001, false)]
239+
[TestCase(">", Direction.Flat, 2.0, false)]
240+
[TestCase(">", Direction.Up, -2.0, true)]
241+
[TestCase(">", Direction.Up, 0.0, true)]
242+
[TestCase(">", Direction.Up, 1.99999, true)]
243+
[TestCase(">", Direction.Up, 2.0, false)]
244+
[TestCase(">", Direction.Up, 2.00001, false)]
245+
[TestCase(">", Direction.Up, 3.0, false)]
246+
[TestCase(">=", Direction.Down, -3.0, true)]
247+
[TestCase(">=", Direction.Down, -2.00001, true)]
248+
[TestCase(">=", Direction.Down, -2.0, true)]
249+
[TestCase(">=", Direction.Down, -1.99999, false)]
250+
[TestCase(">=", Direction.Down, 0.0, false)]
251+
[TestCase(">=", Direction.Down, 2.0, false)]
252+
[TestCase(">=", Direction.Flat, -2.0, true)]
253+
[TestCase(">=", Direction.Flat, -0.00001, true)]
254+
[TestCase(">=", Direction.Flat, 0.0, true)]
255+
[TestCase(">=", Direction.Flat, 0.00001, false)]
256+
[TestCase(">=", Direction.Flat, 2.0, false)]
257+
[TestCase(">=", Direction.Up, -2.0, true)]
258+
[TestCase(">=", Direction.Up, 0.0, true)]
259+
[TestCase(">=", Direction.Up, 1.99999, true)]
260+
[TestCase(">=", Direction.Up, 2.0, true)]
261+
[TestCase(">=", Direction.Up, 2.00001, false)]
262+
[TestCase(">=", Direction.Up, 3.0, false)]
263+
public void FloatComparisonOperatorsWorkWithoutExplicitCast(string @operator, Direction operand1, double operand2, bool expectedResult)
264+
{
265+
using var _ = Py.GIL();
266+
var module = GetTestOperatorsModule(@operator, operand1, operand2);
267+
268+
Assert.AreEqual(expectedResult, module.InvokeMethod("operation1").As<bool>());
269+
270+
var invertedOperationExpectedResult = (@operator.StartsWith('<') || @operator.StartsWith('>')) && Convert.ToInt64(operand1) != operand2
271+
? !expectedResult
272+
: expectedResult;
273+
Assert.AreEqual(invertedOperationExpectedResult, module.InvokeMethod("operation2").As<bool>());
274+
}
275+
276+
public static IEnumerable<TestCaseData> SameEnumTypeComparisonOperatorsTestCases
277+
{
278+
get
279+
{
280+
var operators = new[] { "==", "!=", "<", "<=", ">", ">=" };
281+
var enumValues = Enum.GetValues<Direction>();
282+
283+
foreach (var enumValue in enumValues)
284+
{
285+
foreach (var enumValue2 in enumValues)
286+
{
287+
yield return new TestCaseData("==", enumValue, enumValue2, enumValue == enumValue2);
288+
yield return new TestCaseData("!=", enumValue, enumValue2, enumValue != enumValue2);
289+
yield return new TestCaseData("<", enumValue, enumValue2, enumValue < enumValue2);
290+
yield return new TestCaseData("<=", enumValue, enumValue2, enumValue <= enumValue2);
291+
yield return new TestCaseData(">", enumValue, enumValue2, enumValue > enumValue2);
292+
yield return new TestCaseData(">=", enumValue, enumValue2, enumValue >= enumValue2);
293+
}
294+
}
295+
}
296+
}
297+
298+
[TestCaseSource(nameof(SameEnumTypeComparisonOperatorsTestCases))]
299+
public void SameEnumTypeComparisonOperatorsWorkWithoutExplicitCast(string @operator, Direction operand1, Direction operand2, bool expectedResult)
300+
{
301+
using var _ = Py.GIL();
302+
var module = PyModule.FromString("SameEnumTypeComparisonOperatorsWorkWithoutExplicitCast", $@"
303+
from clr import AddReference
304+
AddReference(""Python.EmbeddingTest"")
305+
306+
from Python.EmbeddingTest import *
307+
308+
def operation():
309+
return {nameof(EnumTests)}.{nameof(Direction)}.{operand1} {@operator} {nameof(EnumTests)}.{nameof(Direction)}.{operand2}
310+
");
311+
312+
Assert.AreEqual(expectedResult, module.InvokeMethod("operation").As<bool>());
313+
}
314+
315+
[TestCase("==", Direction.Down, "Down", true)]
316+
[TestCase("==", Direction.Down, "Flat", false)]
317+
[TestCase("==", Direction.Down, "Up", false)]
318+
[TestCase("==", Direction.Flat, "Down", false)]
319+
[TestCase("==", Direction.Flat, "Flat", true)]
320+
[TestCase("==", Direction.Flat, "Up", false)]
321+
[TestCase("==", Direction.Up, "Down", false)]
322+
[TestCase("==", Direction.Up, "Flat", false)]
323+
[TestCase("==", Direction.Up, "Up", true)]
324+
[TestCase("!=", Direction.Down, "Down", false)]
325+
[TestCase("!=", Direction.Down, "Flat", true)]
326+
[TestCase("!=", Direction.Down, "Up", true)]
327+
[TestCase("!=", Direction.Flat, "Down", true)]
328+
[TestCase("!=", Direction.Flat, "Flat", false)]
329+
[TestCase("!=", Direction.Flat, "Up", true)]
330+
[TestCase("!=", Direction.Up, "Down", true)]
331+
[TestCase("!=", Direction.Up, "Flat", true)]
332+
[TestCase("!=", Direction.Up, "Up", false)]
333+
public void EnumComparisonOperatorsWorkWithString(string @operator, Direction operand1, string operand2, bool expectedResult)
334+
{
335+
using var _ = Py.GIL();
336+
var module = PyModule.FromString("EnumComparisonOperatorsWorkWithString", $@"
337+
from clr import AddReference
338+
AddReference(""Python.EmbeddingTest"")
339+
340+
from Python.EmbeddingTest import *
341+
342+
def operation1():
343+
return {nameof(EnumTests)}.{nameof(Direction)}.{operand1} {@operator} ""{operand2}""
344+
345+
def operation2():
346+
return ""{operand2}"" {@operator} {nameof(EnumTests)}.{nameof(Direction)}.{operand1}
347+
");
348+
349+
Assert.AreEqual(expectedResult, module.InvokeMethod("operation1").As<bool>());
350+
Assert.AreEqual(expectedResult, module.InvokeMethod("operation2").As<bool>());
351+
}
352+
}
353+
}

0 commit comments

Comments
 (0)