Skip to content

Commit fc4e67e

Browse files
authored
Merge pull request #95 from jhonabreul/bug-matching-pyobject-args-overloads-first
Fix matching PyObject arguments overloads first
2 parents 2c675ce + e26db13 commit fc4e67e

File tree

5 files changed

+180
-25
lines changed

5 files changed

+180
-25
lines changed

src/embed_tests/TestMethodBinder.cs

Lines changed: 127 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
using NUnit.Framework;
55
using System.Collections.Generic;
66
using System.Diagnostics;
7-
using static Python.Runtime.Py;
87

98
namespace Python.EmbeddingTest
109
{
@@ -71,6 +70,7 @@ def TestEnumerable(self):
7170
public void SetUp()
7271
{
7372
PythonEngine.Initialize();
73+
using var _ = Py.GIL();
7474

7575
try
7676
{
@@ -80,10 +80,7 @@ public void SetUp()
8080
{
8181
}
8282

83-
using (Py.GIL())
84-
{
85-
module = PyModule.FromString("module", testModule).GetAttr("PythonModel").Invoke();
86-
}
83+
module = PyModule.FromString("module", testModule).GetAttr("PythonModel").Invoke();
8784
}
8885

8986
[OneTimeTearDown]
@@ -926,6 +923,131 @@ def call_method(instance):
926923
Assert.IsFalse(Exceptions.ErrorOccurred());
927924
}
928925

926+
public class CSharpClass2
927+
{
928+
public string CalledMethodMessage { get; private set; } = string.Empty;
929+
930+
public void Clear()
931+
{
932+
CalledMethodMessage = string.Empty;
933+
}
934+
935+
public void Method()
936+
{
937+
CalledMethodMessage = "Overload 1";
938+
}
939+
940+
public void Method(CSharpClass csharpClassArgument, decimal decimalArgument = 1.2m, PyObject pyObjectKwArgument = null)
941+
{
942+
CalledMethodMessage = "Overload 2";
943+
}
944+
945+
public void Method(PyObject pyObjectArgument, decimal decimalArgument = 1.2m, object objectArgument = null)
946+
{
947+
CalledMethodMessage = "Overload 3";
948+
}
949+
950+
// This must be matched when passing just a single argument and it's a PyObject,
951+
// event though the PyObject kwarg in the second overload has more precedence.
952+
// But since it will not be passed, this overload must be called.
953+
public void Method(PyObject pyObjectArgument, decimal decimalArgument = 1.2m, int intArgument = 0)
954+
{
955+
CalledMethodMessage = "Overload 4";
956+
}
957+
}
958+
959+
[Test]
960+
public void PyObjectArgsHavePrecedenceOverOtherTypes()
961+
{
962+
using var _ = Py.GIL();
963+
964+
var instance = new CSharpClass2();
965+
using var pyInstance = instance.ToPython();
966+
using var pyArg = new CSharpClass().ToPython();
967+
968+
Assert.DoesNotThrow(() =>
969+
{
970+
// We are passing a PyObject and not using the named arguments,
971+
// that overload must be called without converting the PyObject to CSharpClass
972+
pyInstance.InvokeMethod("Method", pyArg);
973+
});
974+
975+
Assert.AreEqual("Overload 4", instance.CalledMethodMessage);
976+
Assert.IsFalse(Exceptions.ErrorOccurred());
977+
instance.Clear();
978+
979+
// With the first named argument
980+
Assert.DoesNotThrow(() =>
981+
{
982+
using var kwargs = Py.kw("decimalArgument", 1.234m);
983+
pyInstance.InvokeMethod("Method", new[] { pyArg }, kwargs);
984+
});
985+
986+
Assert.AreEqual("Overload 4", instance.CalledMethodMessage);
987+
Assert.IsFalse(Exceptions.ErrorOccurred());
988+
instance.Clear();
989+
990+
// Snake case version
991+
Assert.DoesNotThrow(() =>
992+
{
993+
using var kwargs = Py.kw("decimal_argument", 1.234m);
994+
pyInstance.InvokeMethod("method", new[] { pyArg }, kwargs);
995+
});
996+
997+
Assert.AreEqual("Overload 4", instance.CalledMethodMessage);
998+
Assert.IsFalse(Exceptions.ErrorOccurred());
999+
}
1000+
1001+
[Test]
1002+
public void OtherTypesHavePrecedenceOverPyObjectArgsIfMoreArgsAreMatched()
1003+
{
1004+
using var _ = Py.GIL();
1005+
1006+
var instance = new CSharpClass2();
1007+
using var pyInstance = instance.ToPython();
1008+
using var pyArg = new CSharpClass().ToPython();
1009+
1010+
Assert.DoesNotThrow(() =>
1011+
{
1012+
using var kwargs = Py.kw("pyObjectKwArgument", new CSharpClass2());
1013+
pyInstance.InvokeMethod("Method", new[] { pyArg }, kwargs);
1014+
});
1015+
1016+
Assert.AreEqual("Overload 2", instance.CalledMethodMessage);
1017+
Assert.IsFalse(Exceptions.ErrorOccurred());
1018+
instance.Clear();
1019+
1020+
Assert.DoesNotThrow(() =>
1021+
{
1022+
using var kwargs = Py.kw("py_object_kw_argument", new CSharpClass2());
1023+
pyInstance.InvokeMethod("method", new[] { pyArg }, kwargs);
1024+
});
1025+
1026+
Assert.AreEqual("Overload 2", instance.CalledMethodMessage);
1027+
Assert.IsFalse(Exceptions.ErrorOccurred());
1028+
instance.Clear();
1029+
1030+
Assert.DoesNotThrow(() =>
1031+
{
1032+
using var kwargs = Py.kw("objectArgument", "somestring");
1033+
pyInstance.InvokeMethod("Method", new[] { pyArg }, kwargs);
1034+
});
1035+
1036+
Assert.AreEqual("Overload 3", instance.CalledMethodMessage);
1037+
Assert.IsFalse(Exceptions.ErrorOccurred());
1038+
instance.Clear();
1039+
1040+
Assert.DoesNotThrow(() =>
1041+
{
1042+
using var kwargs = Py.kw("object_argument", "somestring");
1043+
pyInstance.InvokeMethod("method", new[] { pyArg }, kwargs);
1044+
});
1045+
1046+
Assert.AreEqual("Overload 3", instance.CalledMethodMessage);
1047+
Assert.IsFalse(Exceptions.ErrorOccurred());
1048+
instance.Clear();
1049+
}
1050+
9291051
[Test]
9301052
public void BindsConstructorToSnakeCasedArgumentsVersion([Values] bool useCamelCase, [Values] bool passOptionalArgument)
9311053
{

src/perf_tests/Python.PerformanceTests.csproj

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
1414
</PackageReference>
1515
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.*" />
16-
<PackageReference Include="quantconnect.pythonnet" Version="2.0.39" GeneratePathProperty="true">
16+
<PackageReference Include="quantconnect.pythonnet" Version="2.0.40" GeneratePathProperty="true">
1717
<IncludeAssets>compile</IncludeAssets>
1818
</PackageReference>
1919
</ItemGroup>
@@ -25,7 +25,7 @@
2525
</Target>
2626

2727
<Target Name="CopyBaseline" AfterTargets="Build">
28-
<Copy SourceFiles="$(NuGetPackageRoot)quantconnect.pythonnet\2.0.39\lib\net6.0\Python.Runtime.dll" DestinationFolder="$(OutDir)baseline" />
28+
<Copy SourceFiles="$(NuGetPackageRoot)quantconnect.pythonnet\2.0.40\lib\net6.0\Python.Runtime.dll" DestinationFolder="$(OutDir)baseline" />
2929
</Target>
3030

3131
<Target Name="CopyNewBuild" AfterTargets="Build">

src/runtime/MethodBinder.cs

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
using System.Collections;
33
using System.Collections.Generic;
44
using System.Linq;
5-
using System.Numerics;
65
using System.Reflection;
76
using System.Text;
87

@@ -320,22 +319,46 @@ internal List<MethodInformation> GetMethods()
320319
/// See: https://github.com/jythontools/jython/blob/master/src/org/python/core/ReflectedArgs.java#L192
321320
/// </remarks>
322321
private static int GetPrecedence(MethodInformation methodInformation)
322+
{
323+
return GetMatchedArgumentsPrecedence(methodInformation, null, null);
324+
}
325+
326+
/// <summary>
327+
/// Gets the precedence of a method's arguments, considering only those arguments that have been matched,
328+
/// that is, those that are not default values.
329+
/// </summary>
330+
private static int GetMatchedArgumentsPrecedence(MethodInformation methodInformation, int? matchedPositionalArgsCount, IEnumerable<string> matchedKwargsNames)
323331
{
324332
ParameterInfo[] pi = methodInformation.ParameterInfo;
325333
var mi = methodInformation.MethodBase;
326334
int val = mi.IsStatic ? 3000 : 0;
327-
int num = pi.Length;
335+
var isOperatorMethod = OperatorMethod.IsOperatorMethod(methodInformation.MethodBase);
328336

329337
val += mi.IsGenericMethod ? 1 : 0;
330-
for (var i = 0; i < num; i++)
338+
339+
if (!matchedPositionalArgsCount.HasValue)
340+
{
341+
for (var i = 0; i < pi.Length; i++)
342+
{
343+
val += ArgPrecedence(pi[i].ParameterType, isOperatorMethod);
344+
}
345+
}
346+
else
331347
{
332-
val += ArgPrecedence(pi[i].ParameterType, methodInformation);
348+
matchedKwargsNames ??= Array.Empty<string>();
349+
for (var i = 0; i < pi.Length; i++)
350+
{
351+
if (i < matchedPositionalArgsCount || matchedKwargsNames.Contains(methodInformation.ParameterNames[i]))
352+
{
353+
val += ArgPrecedence(pi[i].ParameterType, isOperatorMethod);
354+
}
355+
}
333356
}
334357

335358
var info = mi as MethodInfo;
336359
if (info != null)
337360
{
338-
val += ArgPrecedence(info.ReturnType, methodInformation);
361+
val += ArgPrecedence(info.ReturnType, isOperatorMethod);
339362
if (mi.DeclaringType == mi.ReflectedType)
340363
{
341364
val += methodInformation.IsOriginal ? 0 : 300000;
@@ -352,15 +375,15 @@ private static int GetPrecedence(MethodInformation methodInformation)
352375
/// <summary>
353376
/// Return a precedence value for a particular Type object.
354377
/// </summary>
355-
internal static int ArgPrecedence(Type t, MethodInformation mi)
378+
internal static int ArgPrecedence(Type t, bool isOperatorMethod)
356379
{
357380
Type objectType = typeof(object);
358381
if (t == objectType)
359382
{
360383
return 3000;
361384
}
362385

363-
if (t.IsAssignableFrom(typeof(PyObject)) && !OperatorMethod.IsOperatorMethod(mi.MethodBase))
386+
if (t.IsAssignableFrom(typeof(PyObject)) && !isOperatorMethod)
364387
{
365388
return -1;
366389
}
@@ -372,7 +395,7 @@ internal static int ArgPrecedence(Type t, MethodInformation mi)
372395
{
373396
return 2500;
374397
}
375-
return 100 + ArgPrecedence(e, mi);
398+
return 100 + ArgPrecedence(e, isOperatorMethod);
376399
}
377400

378401
TypeCode tc = Type.GetTypeCode(t);
@@ -452,6 +475,7 @@ internal Binding Bind(BorrowedReference inst, BorrowedReference args, BorrowedRe
452475
var methods = info == null ? GetMethods()
453476
: new List<MethodInformation>(1) { new MethodInformation(info, true) };
454477

478+
int pyArgCount = (int)Runtime.PyTuple_Size(args);
455479
var matches = new List<MatchedMethod>(methods.Count);
456480
List<MatchedMethod> matchesUsingImplicitConversion = null;
457481

@@ -463,7 +487,6 @@ internal Binding Bind(BorrowedReference inst, BorrowedReference args, BorrowedRe
463487
var pi = methodInformation.ParameterInfo;
464488
// Avoid accessing the parameter names property unless necessary
465489
var paramNames = hasNamedArgs ? methodInformation.ParameterNames : Array.Empty<string>();
466-
int pyArgCount = (int)Runtime.PyTuple_Size(args);
467490

468491
// Special case for operators
469492
bool isOperator = OperatorMethod.IsOperatorMethod(mi);
@@ -695,7 +718,7 @@ internal Binding Bind(BorrowedReference inst, BorrowedReference args, BorrowedRe
695718
}
696719
}
697720

698-
var match = new MatchedMethod(kwargsMatched, margs, outs, mi);
721+
var match = new MatchedMethod(kwargsMatched, margs, outs, methodInformation);
699722
if (usedImplicitConversion)
700723
{
701724
if (matchesUsingImplicitConversion == null)
@@ -718,8 +741,17 @@ internal Binding Bind(BorrowedReference inst, BorrowedReference args, BorrowedRe
718741
// We favor matches that do not use implicit conversion
719742
var matchesTouse = matches.Count > 0 ? matches : matchesUsingImplicitConversion;
720743

721-
// The best match would be the one with the most named arguments matched
722-
var bestMatch = matchesTouse.MaxBy(x => x.KwargsMatched);
744+
// The best match would be the one with the most named arguments matched.
745+
// But if multiple matches have the same max number of named arguments matched,
746+
// we solve the ambiguity by taking the one with the highest precedence but only
747+
// considering the actual arguments passed, ignoring the optional arguments for
748+
// which the default values were used
749+
var bestMatch = matchesTouse
750+
.GroupBy(x => x.KwargsMatched)
751+
.OrderByDescending(x => x.Key)
752+
.First()
753+
.MinBy(x => GetMatchedArgumentsPrecedence(x.MethodInformation, pyArgCount, kwArgDict?.Keys));
754+
723755
var margs = bestMatch.ManagedArgs;
724756
var outs = bestMatch.Outs;
725757
var mi = bestMatch.Method;
@@ -1084,14 +1116,15 @@ private readonly struct MatchedMethod
10841116
public int KwargsMatched { get; }
10851117
public object?[] ManagedArgs { get; }
10861118
public int Outs { get; }
1087-
public MethodBase Method { get; }
1119+
public MethodInformation MethodInformation { get; }
1120+
public MethodBase Method => MethodInformation.MethodBase;
10881121

1089-
public MatchedMethod(int kwargsMatched, object?[] margs, int outs, MethodBase mb)
1122+
public MatchedMethod(int kwargsMatched, object?[] margs, int outs, MethodInformation methodInformation)
10901123
{
10911124
KwargsMatched = kwargsMatched;
10921125
ManagedArgs = margs;
10931126
Outs = outs;
1094-
Method = mb;
1127+
MethodInformation = methodInformation;
10951128
}
10961129
}
10971130

src/runtime/Properties/AssemblyInfo.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,5 @@
44
[assembly: InternalsVisibleTo("Python.EmbeddingTest, PublicKey=00240000048000009400000006020000002400005253413100040000110000005ffd8f49fb44ab0641b3fd8d55e749f716e6dd901032295db641eb98ee46063cbe0d4a1d121ef0bc2af95f8a7438d7a80a3531316e6b75c2dae92fb05a99f03bf7e0c03980e1c3cfb74ba690aca2f3339ef329313bcc5dccced125a4ffdc4531dcef914602cd5878dc5fbb4d4c73ddfbc133f840231343e013762884d6143189")]
55
[assembly: InternalsVisibleTo("Python.Test, PublicKey=00240000048000009400000006020000002400005253413100040000110000005ffd8f49fb44ab0641b3fd8d55e749f716e6dd901032295db641eb98ee46063cbe0d4a1d121ef0bc2af95f8a7438d7a80a3531316e6b75c2dae92fb05a99f03bf7e0c03980e1c3cfb74ba690aca2f3339ef329313bcc5dccced125a4ffdc4531dcef914602cd5878dc5fbb4d4c73ddfbc133f840231343e013762884d6143189")]
66

7-
[assembly: AssemblyVersion("2.0.39")]
8-
[assembly: AssemblyFileVersion("2.0.39")]
7+
[assembly: AssemblyVersion("2.0.40")]
8+
[assembly: AssemblyFileVersion("2.0.40")]

src/runtime/Python.Runtime.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
<RootNamespace>Python.Runtime</RootNamespace>
66
<AssemblyName>Python.Runtime</AssemblyName>
77
<PackageId>QuantConnect.pythonnet</PackageId>
8-
<Version>2.0.39</Version>
8+
<Version>2.0.40</Version>
99
<GenerateAssemblyInfo>false</GenerateAssemblyInfo>
1010
<PackageLicenseFile>LICENSE</PackageLicenseFile>
1111
<RepositoryUrl>https://github.com/pythonnet/pythonnet</RepositoryUrl>

0 commit comments

Comments
 (0)