Skip to content

Commit c54095f

Browse files
committed
Centralize Python/C# model detection logic
1 parent a78004a commit c54095f

File tree

2 files changed

+41
-38
lines changed

2 files changed

+41
-38
lines changed

Common/Securities/Security.cs

Lines changed: 11 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
using QuantConnect.Data.Fundamental;
3434
using QuantConnect.Interfaces;
3535
using QuantConnect.Data.Shortable;
36+
using QuantConnect.Util;
3637

3738
namespace QuantConnect.Securities
3839
{
@@ -725,7 +726,7 @@ public void SetFeeModel(IFeeModel feelModel)
725726
/// <param name="feelModel">Model that represents a fee model</param>
726727
public void SetFeeModel(PyObject feelModel)
727728
{
728-
FeeModel = new FeeModelPythonWrapper(feelModel);
729+
FeeModel = PythonUtil.CreateModelOrWrapper<IFeeModel, FeeModelPythonWrapper>(feelModel);
729730
}
730731

731732
/// <summary>
@@ -743,7 +744,7 @@ public void SetFillModel(IFillModel fillModel)
743744
/// <param name="fillModel">Model that represents a fill model</param>
744745
public void SetFillModel(PyObject fillModel)
745746
{
746-
FillModel = new FillModelPythonWrapper(fillModel);
747+
FillModel = PythonUtil.CreateModelOrWrapper<IFillModel, FillModelPythonWrapper>(fillModel);
747748
}
748749

749750
/// <summary>
@@ -761,7 +762,7 @@ public void SetSettlementModel(ISettlementModel settlementModel)
761762
/// <param name="settlementModel">Model that represents a settlement model</param>
762763
public void SetSettlementModel(PyObject settlementModel)
763764
{
764-
SettlementModel = new SettlementModelPythonWrapper(settlementModel);
765+
SettlementModel = PythonUtil.CreateModelOrWrapper<ISettlementModel, SettlementModelPythonWrapper>(settlementModel);
765766
}
766767

767768
/// <summary>
@@ -779,7 +780,7 @@ public void SetSlippageModel(ISlippageModel slippageModel)
779780
/// <param name="slippageModel">Model that represents a slippage model</param>
780781
public void SetSlippageModel(PyObject slippageModel)
781782
{
782-
SlippageModel = new SlippageModelPythonWrapper(slippageModel);
783+
SlippageModel = PythonUtil.CreateModelOrWrapper<ISlippageModel, SlippageModelPythonWrapper>(slippageModel);
783784
}
784785

785786
/// <summary>
@@ -797,7 +798,7 @@ public void SetVolatilityModel(IVolatilityModel volatilityModel)
797798
/// <param name="volatilityModel">Model that represents a volatility model</param>
798799
public void SetVolatilityModel(PyObject volatilityModel)
799800
{
800-
VolatilityModel = new VolatilityModelPythonWrapper(volatilityModel);
801+
VolatilityModel = PythonUtil.CreateModelOrWrapper<IVolatilityModel, VolatilityModelPythonWrapper>(volatilityModel);
801802
}
802803

803804
/// <summary>
@@ -815,7 +816,7 @@ public void SetBuyingPowerModel(IBuyingPowerModel buyingPowerModel)
815816
/// <param name="pyObject">Model that represents a security's model of buying power</param>
816817
public void SetBuyingPowerModel(PyObject pyObject)
817818
{
818-
SetBuyingPowerModel(new BuyingPowerModelPythonWrapper(pyObject));
819+
BuyingPowerModel = PythonUtil.CreateModelOrWrapper<IBuyingPowerModel, BuyingPowerModelPythonWrapper>(pyObject);
819820
}
820821

821822
/// <summary>
@@ -833,7 +834,7 @@ public void SetMarginInterestRateModel(IMarginInterestRateModel marginInterestRa
833834
/// <param name="pyObject">Model that represents a security's model of margin interest rate</param>
834835
public void SetMarginInterestRateModel(PyObject pyObject)
835836
{
836-
SetMarginInterestRateModel(new MarginInterestRateModelPythonWrapper(pyObject));
837+
MarginInterestRateModel = PythonUtil.CreateModelOrWrapper<IMarginInterestRateModel, MarginInterestRateModelPythonWrapper>(pyObject);
837838
}
838839

839840
/// <summary>
@@ -851,7 +852,7 @@ public void SetMarginModel(IBuyingPowerModel marginModel)
851852
/// <param name="pyObject">Model that represents a security's model of buying power</param>
852853
public void SetMarginModel(PyObject pyObject)
853854
{
854-
SetMarginModel(new BuyingPowerModelPythonWrapper(pyObject));
855+
MarginModel = PythonUtil.CreateModelOrWrapper<IBuyingPowerModel, BuyingPowerModelPythonWrapper>(pyObject);
855856
}
856857

857858
/// <summary>
@@ -860,21 +861,7 @@ public void SetMarginModel(PyObject pyObject)
860861
/// <param name="pyObject">Python class that represents a custom shortable provider</param>
861862
public void SetShortableProvider(PyObject pyObject)
862863
{
863-
if (pyObject.TryConvert<IShortableProvider>(out var shortableProvider))
864-
{
865-
SetShortableProvider(shortableProvider);
866-
}
867-
else if (Extensions.TryConvert<IShortableProvider>(pyObject, out _, allowPythonDerivative: true))
868-
{
869-
SetShortableProvider(new ShortableProviderPythonWrapper(pyObject));
870-
}
871-
else
872-
{
873-
using (Py.GIL())
874-
{
875-
throw new Exception($"SetShortableProvider: {pyObject.Repr()} is not a valid argument");
876-
}
877-
}
864+
ShortableProvider = PythonUtil.CreateModelOrWrapper<IShortableProvider, ShortableProviderPythonWrapper>(pyObject);
878865
}
879866

880867
/// <summary>
@@ -893,21 +880,7 @@ public void SetShortableProvider(IShortableProvider shortableProvider)
893880
/// <exception cref="ArgumentException"></exception>
894881
public void SetDataFilter(PyObject pyObject)
895882
{
896-
if (pyObject.TryConvert<ISecurityDataFilter>(out var dataFilter))
897-
{
898-
SetDataFilter(dataFilter);
899-
}
900-
else if (Extensions.TryConvert<ISecurityDataFilter>(pyObject, out _, allowPythonDerivative: true))
901-
{
902-
SetDataFilter(new SecurityDataFilterPythonWrapper(pyObject));
903-
}
904-
else
905-
{
906-
using (Py.GIL())
907-
{
908-
throw new ArgumentException($"SetDataFilter: {pyObject.Repr()} is not a valid argument");
909-
}
910-
}
883+
DataFilter = PythonUtil.CreateModelOrWrapper<ISecurityDataFilter, SecurityDataFilterPythonWrapper>(pyObject);
911884
}
912885

913886
/// <summary>

Common/Util/PythonUtil.cs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,5 +361,35 @@ public static IEnumerable<Symbol> ConvertToSymbols(PyObject input)
361361
}
362362
return symbolsList;
363363
}
364+
365+
/// <summary>
366+
/// Creates either a pure C# model instance or a Python wrapper based on the input PyObject
367+
/// </summary>
368+
/// <typeparam name="TInterface">The interface type expected</typeparam>
369+
/// <typeparam name="TWrapper">The Python wrapper type for TInterface</typeparam>
370+
/// <param name="pyObject">The Python object to convert</param>
371+
/// <returns>Either a pure C# instance or a Python wrapper implementing TInterface</returns>
372+
/// <exception cref="ArgumentException">Thrown when pyObject is not a valid TInterface</exception>
373+
public static TInterface CreateModelOrWrapper<TInterface, TWrapper>(PyObject pyObject)
374+
where TInterface : class
375+
where TWrapper : TInterface
376+
{
377+
using (Py.GIL())
378+
{
379+
if (pyObject.TryConvert<TInterface>(out var model))
380+
{
381+
// This object is pure C#
382+
return model;
383+
}
384+
385+
if (Extensions.TryConvert<TInterface>(pyObject, out _, allowPythonDerivative: true))
386+
{
387+
// Create the appropriate Python wrapper
388+
return (TInterface)Activator.CreateInstance(typeof(TWrapper), pyObject);
389+
}
390+
391+
throw new ArgumentException($"Invalid argument: {pyObject.Repr()} is not a valid {typeof(TInterface).Name}");
392+
}
393+
}
364394
}
365395
}

0 commit comments

Comments
 (0)