Skip to content

Commit 6757e1f

Browse files
committed
feat: bind snake case name methods named parameters along with original method .net to python
1 parent 5ddc78c commit 6757e1f

File tree

4 files changed

+159
-13
lines changed

4 files changed

+159
-13
lines changed

src/embed_tests/ClassManagerTests.cs

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
24

35
using NUnit.Framework;
46

@@ -66,6 +68,20 @@ public static int AddNumbersAndGetHalf_Static(int a, int b)
6668
{
6769
return (a + b) / 2;
6870
}
71+
72+
public string JoinToString(string thisIsAStringParameter,
73+
char thisIsACharParameter,
74+
int thisIsAnIntParameter,
75+
float thisIsAFloatParameter,
76+
double thisIsADoubleParameter,
77+
decimal thisIsADecimalParameter,
78+
bool thisIsABoolParameter,
79+
DateTime thisIsADateTimeParameter)
80+
{
81+
// Join all parameters into a single string separated by "-"
82+
return string.Join("-", thisIsAStringParameter, thisIsACharParameter, thisIsAnIntParameter, thisIsAFloatParameter,
83+
thisIsADoubleParameter, thisIsADecimalParameter, thisIsABoolParameter, string.Format("{0:MMddyyyy}", thisIsADateTimeParameter));
84+
}
6985
}
7086

7187
[TestCase("AddNumbersAndGetHalf", "add_numbers_and_get_half")]
@@ -298,6 +314,92 @@ def RemoveEventHandler(handler):
298314
}
299315
}
300316

317+
private static IEnumerable<TestCaseData> SnakeCasedNamedArgsTestCases
318+
{
319+
get
320+
{
321+
var stringParam = "string";
322+
var charParam = 'c';
323+
var intParam = 1;
324+
var floatParam = 2.0f;
325+
var doubleParam = 3.0;
326+
var decimalParam = 4.0m;
327+
var boolParam = true;
328+
var dateTimeParam = new DateTime(2013, 01, 05);
329+
330+
// 1. All kwargs:
331+
332+
// 1.1. Original method name:
333+
var args = Array.Empty<object>();
334+
var namedArgs = new Dictionary<string, object>()
335+
{
336+
{ "thisIsAStringParameter", stringParam },
337+
{ "thisIsACharParameter", charParam },
338+
{ "thisIsAnIntParameter", intParam },
339+
{ "thisIsAFloatParameter", floatParam },
340+
{ "thisIsADoubleParameter", doubleParam },
341+
{ "thisIsADecimalParameter", decimalParam },
342+
{ "thisIsABoolParameter", boolParam },
343+
{ "thisIsADateTimeParameter", dateTimeParam }
344+
};
345+
yield return new TestCaseData("JoinToString", args, namedArgs);
346+
347+
// 1.2. Snake-cased method name:
348+
namedArgs = new Dictionary<string, object>()
349+
{
350+
{ "this_is_a_string_parameter", stringParam },
351+
{ "this_is_a_char_parameter", charParam },
352+
{ "this_is_an_int_parameter", intParam },
353+
{ "this_is_a_float_parameter", floatParam },
354+
{ "this_is_a_double_parameter", doubleParam },
355+
{ "this_is_a_decimal_parameter", decimalParam },
356+
{ "this_is_a_bool_parameter", boolParam },
357+
{ "this_is_a_date_time_parameter", dateTimeParam }
358+
};
359+
yield return new TestCaseData("join_to_string", args, namedArgs);
360+
361+
// 2. Some args and some kwargs:
362+
363+
// 2.1. Original method name:
364+
args = new object[] { stringParam, charParam, intParam, floatParam };
365+
namedArgs = new Dictionary<string, object>()
366+
{
367+
{ "thisIsADoubleParameter", doubleParam },
368+
{ "thisIsADecimalParameter", decimalParam },
369+
{ "thisIsABoolParameter", boolParam },
370+
{ "thisIsADateTimeParameter", dateTimeParam }
371+
};
372+
yield return new TestCaseData("JoinToString", args, namedArgs);
373+
374+
// 2.2. Snake-cased method name:
375+
namedArgs = new Dictionary<string, object>()
376+
{
377+
{ "this_is_a_double_parameter", doubleParam },
378+
{ "this_is_a_decimal_parameter", decimalParam },
379+
{ "this_is_a_bool_parameter", boolParam },
380+
{ "this_is_a_date_time_parameter", dateTimeParam }
381+
};
382+
yield return new TestCaseData("join_to_string", args, namedArgs);
383+
}
384+
}
385+
386+
[TestCaseSource(nameof(SnakeCasedNamedArgsTestCases))]
387+
public void CanCallSnakeCasedMethodWithSnakeCasedNamedArguments(string methodName, object[] args, Dictionary<string, object> namedArgs)
388+
{
389+
using var obj = new SnakeCaseNamesTesClass().ToPython();
390+
391+
var pyArgs = args.Select(a => a.ToPython()).ToArray();
392+
using var pyNamedArgs = new PyDict();
393+
foreach (var (key, value) in namedArgs)
394+
{
395+
pyNamedArgs[key] = value.ToPython();
396+
}
397+
398+
var result = obj.InvokeMethod(methodName, pyArgs, pyNamedArgs).As<string>();
399+
400+
Assert.AreEqual("string-c-1-2-3-4.0-True-01052013", result);
401+
}
402+
301403
#endregion
302404
}
303405

src/runtime/ClassManager.cs

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ internal static bool ShouldBindEvent(EventInfo ei)
336336
private static ClassInfo GetClassInfo(Type type, ClassBase impl)
337337
{
338338
var ci = new ClassInfo();
339-
var methods = new Dictionary<string, List<MethodBase>>();
339+
var methods = new Dictionary<string, MethodOverloads>();
340340
MethodInfo meth;
341341
ExtensionType ob;
342342
string name;
@@ -450,7 +450,7 @@ private static ClassInfo GetClassInfo(Type type, ClassBase impl)
450450

451451
if (!methods.TryGetValue(name, out var methodList))
452452
{
453-
methodList = methods[name] = new List<MethodBase>();
453+
methodList = methods[name] = new MethodOverloads(true);
454454
}
455455
methodList.Add(meth);
456456

@@ -459,7 +459,7 @@ private static ClassInfo GetClassInfo(Type type, ClassBase impl)
459459
name = name.ToSnakeCase();
460460
if (!methods.TryGetValue(name, out methodList))
461461
{
462-
methodList = methods[name] = new List<MethodBase>();
462+
methodList = methods[name] = new MethodOverloads(false);
463463
}
464464
methodList.Add(meth);
465465
}
@@ -475,7 +475,7 @@ private static ClassInfo GetClassInfo(Type type, ClassBase impl)
475475
name = "__init__";
476476
if (!methods.TryGetValue(name, out methodList))
477477
{
478-
methodList = methods[name] = new List<MethodBase>();
478+
methodList = methods[name] = new MethodOverloads(true);
479479
}
480480
methodList.Add(ctor);
481481
continue;
@@ -550,9 +550,9 @@ private static ClassInfo GetClassInfo(Type type, ClassBase impl)
550550
foreach (var iter in methods)
551551
{
552552
name = iter.Key;
553-
var mlist = iter.Value.ToArray();
553+
var mlist = iter.Value.Methods.ToArray();
554554

555-
ob = new MethodObject(type, name, mlist);
555+
ob = new MethodObject(type, name, mlist, isOriginal: iter.Value.IsOriginal);
556556
ci.members[name] = ob.AllocObject();
557557
if (mlist.Any(OperatorMethod.IsOperatorMethod))
558558
{
@@ -604,6 +604,24 @@ internal ClassInfo()
604604
indexer = null;
605605
}
606606
}
607+
608+
private class MethodOverloads
609+
{
610+
public List<MethodBase> Methods { get; }
611+
612+
public bool IsOriginal { get; }
613+
614+
public MethodOverloads(bool original = true)
615+
{
616+
Methods = new List<MethodBase>();
617+
IsOriginal = original;
618+
}
619+
620+
public void Add(MethodBase method)
621+
{
622+
Methods.Add(method);
623+
}
624+
}
607625
}
608626

609627
}

src/runtime/MethodBinder.cs

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,15 @@ public int Count
4040
}
4141

4242
internal void AddMethod(MethodBase m)
43+
{
44+
AddMethod(m, true);
45+
}
46+
47+
internal void AddMethod(MethodBase m, bool isOriginal)
4348
{
4449
// we added a new method so we have to re sort the method list
4550
init = false;
46-
list.Add(new MethodInformation(m, m.GetParameters()));
51+
list.Add(new MethodInformation(m, m.GetParameters(), isOriginal));
4752
}
4853

4954
/// <summary>
@@ -118,7 +123,7 @@ internal static MethodInfo[] MatchParameters(MethodBase[] mi, Type[] tp)
118123
return result.ToArray();
119124
}
120125

121-
// Given a generic method and the argsTypes previously matched with it,
126+
// Given a generic method and the argsTypes previously matched with it,
122127
// generate the matching method
123128
internal static MethodInfo ResolveGenericMethod(MethodInfo method, Object[] args)
124129
{
@@ -474,11 +479,15 @@ internal Binding Bind(BorrowedReference inst, BorrowedReference args, BorrowedRe
474479

475480
// Must be done after IsOperator section
476481
int clrArgCount = pi.Length;
482+
var parametersSnakeCasedNames = kwArgDict == null || methodInformation.IsOriginal
483+
? null
484+
: pi.Select(p => p.Name.ToSnakeCase()).ToArray();
477485

478486
if (CheckMethodArgumentsMatch(clrArgCount,
479487
pyArgCount,
480488
kwArgDict,
481489
pi,
490+
parametersSnakeCasedNames,
482491
out bool paramsArray,
483492
out ArrayList defaultArgList))
484493
{
@@ -497,7 +506,12 @@ internal Binding Bind(BorrowedReference inst, BorrowedReference args, BorrowedRe
497506
object arg; // Python -> Clr argument
498507

499508
// Check our KWargs for this parameter
500-
bool hasNamedParam = kwArgDict == null ? false : kwArgDict.TryGetValue(parameter.Name, out tempPyObject);
509+
var hasNamedParam = false;
510+
if (kwArgDict != null)
511+
{
512+
var paramName = methodInformation.IsOriginal ? parameter.Name : parametersSnakeCasedNames[paramIndex];
513+
hasNamedParam = kwArgDict.TryGetValue(paramName, out tempPyObject);
514+
}
501515
if(tempPyObject != null)
502516
{
503517
op = tempPyObject;
@@ -766,6 +780,7 @@ private bool CheckMethodArgumentsMatch(int clrArgCount,
766780
int pyArgCount,
767781
Dictionary<string, PyObject> kwargDict,
768782
ParameterInfo[] parameterInfo,
783+
string[] parametersSnakeCasedNames,
769784
out bool paramsArray,
770785
out ArrayList defaultArgList)
771786
{
@@ -788,7 +803,9 @@ private bool CheckMethodArgumentsMatch(int clrArgCount,
788803
{
789804
// If the method doesn't have all of these kw args, it is not a match
790805
// Otherwise just continue on to see if it is a match
791-
if (!kwargDict.All(x => parameterInfo.Any(pi => x.Key == pi.Name)))
806+
if (!kwargDict.All(x => parametersSnakeCasedNames == null
807+
? parameterInfo.Any(pi => x.Key == pi.Name)
808+
: parametersSnakeCasedNames.Any(paramName => x.Key == paramName)))
792809
{
793810
return false;
794811
}
@@ -808,7 +825,7 @@ private bool CheckMethodArgumentsMatch(int clrArgCount,
808825
defaultArgList = new ArrayList();
809826
for (var v = pyArgCount; v < clrArgCount && match; v++)
810827
{
811-
if (kwargDict != null && kwargDict.ContainsKey(parameterInfo[v].Name))
828+
if (kwargDict != null && kwargDict.ContainsKey(parametersSnakeCasedNames == null ? parameterInfo[v].Name : parametersSnakeCasedNames[v]))
812829
{
813830
// we have a keyword argument for this parameter,
814831
// no need to check for a default parameter, but put a null
@@ -977,10 +994,18 @@ internal class MethodInformation
977994

978995
public ParameterInfo[] ParameterInfo { get; }
979996

997+
public bool IsOriginal { get; }
998+
980999
public MethodInformation(MethodBase methodBase, ParameterInfo[] parameterInfo)
1000+
: this(methodBase, parameterInfo, true)
1001+
{
1002+
}
1003+
1004+
public MethodInformation(MethodBase methodBase, ParameterInfo[] parameterInfo, bool isOriginal)
9811005
{
9821006
MethodBase = methodBase;
9831007
ParameterInfo = parameterInfo;
1008+
IsOriginal = isOriginal;
9841009
}
9851010

9861011
public override string ToString()

src/runtime/Types/MethodObject.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ internal class MethodObject : ExtensionType
2828
internal PyString? doc;
2929
internal MaybeType type;
3030

31-
public MethodObject(MaybeType type, string name, MethodBase[] info, bool allow_threads = MethodBinder.DefaultAllowThreads)
31+
public MethodObject(MaybeType type, string name, MethodBase[] info, bool allow_threads = MethodBinder.DefaultAllowThreads,
32+
bool isOriginal = true)
3233
{
3334
this.type = type;
3435
this.name = name;
@@ -37,7 +38,7 @@ public MethodObject(MaybeType type, string name, MethodBase[] info, bool allow_t
3738
foreach (MethodBase item in info)
3839
{
3940
this.infoList.Add(item);
40-
binder.AddMethod(item);
41+
binder.AddMethod(item, isOriginal);
4142
if (item.IsStatic)
4243
{
4344
this.is_static = true;

0 commit comments

Comments
 (0)