1+ import math
12from typing import List
23
4+ import credmark .cmf .model
35from credmark .dto import DTOField , IterableListGenericDTO , PrivateAttr
46
5- from .position import Position , PositionWithPrice
7+ from .adt import Maybe , Some
8+ from .position import Position
9+ from .price import PriceWithQuote
610
711
812class Portfolio (IterableListGenericDTO [Position ]):
913 positions : List [Position ] = DTOField (
1014 default = [], description = 'List of positions' )
1115 _iterator : str = PrivateAttr ('positions' )
1216
13- def get_value (self , price_model = 'price.quote' , block_number = None , quote = None ):
17+ def get_value (self , block_number = None , quote = None ):
1418 """
1519 Returns:
1620 The value of the portfolio using the price_model.
1721
1822 Raises:
1923 ModelDataError: if no pools available for a position's price data.
2024 """
21- if len (self .positions ) > 0 :
22- total = 0
23- for pos in self .positions :
24- total += pos .get_value (price_model , block_number = block_number , quote = quote )
25- return total
26- return 0
25+ non_zero_positions = [position for position in self .positions
26+ if not math .isclose (position .amount , 0 )]
27+
28+ if len (non_zero_positions ) == 0 :
29+ return 0
30+
31+ total = 0
32+ context = credmark .cmf .model .ModelContext .current_context ()
33+ if block_number is None :
34+ block_number = context .block_number
35+
36+ pqs_maybe = context .run_model (
37+ slug = 'price.quote-multiple-maybe' ,
38+ input = Some (some = [
39+ {'base' : p .asset .address } if quote is None
40+ else {'base' : p .asset .address , 'quote' : quote }
41+ for p in non_zero_positions
42+ ]),
43+ block_number = block_number ,
44+ return_type = Some [Maybe [PriceWithQuote ]],
45+ )
46+ for price_maybe , position in zip (pqs_maybe .some , non_zero_positions ):
47+ if price_maybe .just is not None :
48+ total += position .amount * price_maybe .just .price
49+
50+ return total
2751
2852 class Config :
2953 schema_extra : dict = {
@@ -48,47 +72,3 @@ def merge(cls, port1: "Portfolio", port2: "Portfolio"):
4872 positions [pos_key ].amount += pos .amount
4973
5074 return cls (positions = list (positions .values ()))
51-
52-
53- class PortfolioWithPrice (IterableListGenericDTO [PositionWithPrice ]):
54- positions : List [PositionWithPrice ] = DTOField (
55- default = [], description = 'List of positions' )
56- _iterator : str = PrivateAttr ('positions' )
57-
58- def get_value (self , price_model = 'price.quote' , block_number = None , quote = None ):
59- """
60- Returns:
61- The value of the portfolio using the price_model.
62-
63- Raises:
64- ModelDataError: if no pools available for a position's price data.
65- """
66- if len (self .positions ) > 0 :
67- return sum (pos .get_value (price_model , block_number = block_number , quote = quote )
68- for pos in self .positions ) # pylint:disable=not-an-iterable
69- return 0
70-
71- class Config :
72- schema_extra : dict = {
73- 'examples' : [{'positions' : [exp ]}
74- for exp in PositionWithPrice .Config .schema_extra ['examples' ]]
75- }
76-
77- @classmethod
78- def merge (cls , port1 : "Portfolio" , port2 : "Portfolio" ):
79- positions = {}
80- for pos in port1 :
81- pos_key = str (pos .asset .address )
82- if positions .get (pos_key , None ) is None :
83- positions [pos_key ] = pos .copy ()
84- else :
85- positions [pos_key ].amount += pos .amount
86-
87- for pos in port2 :
88- pos_key = str (pos .asset .address )
89- if positions .get (pos_key , None ) is None :
90- positions [pos_key ] = pos .copy ()
91- else :
92- positions [pos_key ].amount += pos .amount
93-
94- return cls (positions = list (positions .values ()))
0 commit comments