@@ -25,7 +25,7 @@ def __init__(self, name: str = "", ordering_key: Optional[int] = None) -> None:
2525 :param name: name of node
2626 """
2727 self .uid = next (Node ._uid_generator )
28- self .name = name if name else str (self .uid )
28+ self .name = name or str (self .uid )
2929 self .ordering_key = ordering_key if ordering_key is not None else self .uid
3030 self .neighbors : dict [int , Node ] = {} # keys as senders uid
3131 self .received_messages : dict [int , Any ] = {} # keys as senders uid, values as messages
@@ -34,10 +34,7 @@ def register_neighbor(self, neighbor: Node) -> None:
3434 self .neighbors [neighbor .uid ] = neighbor
3535
3636 def __str__ (self ) -> str :
37- if self .name :
38- return self .name
39- else :
40- return str (self .uid )
37+ return self .name or str (self .uid )
4138
4239 def get_neighbors (self ) -> list [int ]:
4340 """
@@ -96,24 +93,27 @@ def phi(x: npt.NDArray[np.float_]) -> Any:
9693
9794class VNode (Node ):
9895 """Variable nodes in Tanner graph"""
99- def __init__ (self , channel_model : ChannelModel , ordering_key : int , name : str = "" ):
96+ def __init__ (self , ordering_key : int , name : str = "" , channel_model : Optional [ ChannelModel ] = None ):
10097 """
10198 :param channel_model: a function which receives channel outputs anr returns relevant message
10299 :param ordering_key: used to order nodes per their order in the parity check matrix
103100 :param name: optional name of node
104101 """
105- self .channel_model : ChannelModel = channel_model
106- self .channel_symbol : int = None # type: ignore # currently assuming hard channel symbols
102+ self .channel_model : Optional [ ChannelModel ] = channel_model
103+ self .channel_symbol : np . float_ = None # type: ignore
107104 self .channel_llr : np .float_ = None # type: ignore
108105 super ().__init__ (name , ordering_key )
109106
110- def initialize (self , channel_symbol : int ) -> None : # type: ignore
107+ def initialize (self , channel_symbol : np . float_ ) -> None : # type: ignore
111108 """
112109 clear received messages and initialize channel llr with channel bit
113110 :param channel_symbol: bit received from channel, currently assumes hard inputs.
114111 """
115112 self .channel_symbol = channel_symbol
116- self .channel_llr = self .channel_model (channel_symbol )
113+ if self .channel_model is not None :
114+ self .channel_llr = self .channel_model (channel_symbol )
115+ else :
116+ self .channel_llr = channel_symbol
117117 self .received_messages = {node_uid : 0 for node_uid in self .neighbors }
118118
119119 def message (self , requester_uid : int ) -> np .float_ :
@@ -127,4 +127,4 @@ def message(self, requester_uid: int) -> np.float_:
127127
128128 def estimate (self ) -> np .float_ :
129129 """provide a soft bit estimate"""
130- return self .channel_llr + np .sum (list (self .received_messages .values ())) # type: ignore
130+ return self .channel_llr + np .sum (list (self .received_messages .values ()), dtype = np . float_ ) # type: ignore
0 commit comments