@@ -60,10 +60,20 @@ def from_config(cls, config):
6060 def get_head_shapes_from_target_shape (self , target_shape ):
6161 raise NotImplementedError
6262
63- def set_head_shapes_from_target_shape (self , target_shape ):
64- self .head_shapes = self .get_head_shapes_from_target_shape (target_shape )
65-
66- def get_subnet (self , key : str ):
63+ def get_subnet (self , key : str ) -> keras .Layer :
64+ """For a specified key, request a subnet to be used for projecting the shared condition embedding
65+ before reshaping to the heads output shape.
66+
67+ Parameters
68+ ----------
69+ key : str
70+ Name of head for which to request a link.
71+
72+ Returns
73+ -------
74+ link : keras.Layer
75+ Subnet projecting the shared condition embedding.
76+ """
6777 if key not in self .subnets .keys ():
6878 return keras .layers .Identity ()
6979 else :
@@ -77,15 +87,15 @@ def get_link(self, key: str):
7787 else :
7888 return self .links [key ]
7989
80- def get_head (self , key : str ):
90+ def get_head (self , key : str , shape : Shape ):
8191 subnet = self .get_subnet (key )
82- head_shape = self .head_shapes [key ]
83- dense = keras .layers .Dense (units = math .prod (head_shape ))
84- reshape = keras .layers .Reshape (target_shape = head_shape )
92+ dense = keras .layers .Dense (units = math .prod (shape ))
93+ reshape = keras .layers .Reshape (target_shape = shape )
8594 link = self .get_link (key )
8695 return keras .Sequential ([subnet , dense , reshape , link ])
8796
8897 def score (self , estimates : dict [str , Tensor ], targets : Tensor , weights : Tensor ) -> Tensor :
98+ """Scores a probabilistic estimate based of a distribution based on samples of that distribution."""
8999 raise NotImplementedError
90100
91101 def aggregate (self , scores : Tensor , weights : Tensor = None ):
@@ -114,7 +124,7 @@ def __init__(
114124 "k" : k ,
115125 }
116126
117- def get_head_shapes_from_target_shape (self , target_shape ):
127+ def get_head_shapes_from_target_shape (self , target_shape : Shape ):
118128 # keras.saving.load_model sometimes passes target_shape as a list.
119129 # This is why I force a conversion to tuple here.
120130 target_shape = tuple (target_shape )
@@ -180,7 +190,7 @@ def get_config(self):
180190 base_config = super ().get_config ()
181191 return base_config | self .config
182192
183- def get_head_shapes_from_target_shape (self , target_shape ):
193+ def get_head_shapes_from_target_shape (self , target_shape : Shape ):
184194 # keras.saving.load_model sometimes passes target_shape as a list.
185195 # This is why I force a conversion to tuple here.
186196 target_shape = tuple (target_shape )
@@ -240,7 +250,7 @@ def get_config(self):
240250 base_config = super ().get_config ()
241251 return base_config | self .config
242252
243- def get_head_shapes_from_target_shape (self , target_shape ) -> dict [str , Shape ]:
253+ def get_head_shapes_from_target_shape (self , target_shape : Shape ) -> dict [str , Shape ]:
244254 self .D = target_shape [- 1 ]
245255 return dict (
246256 mean = (self .D ,),
0 commit comments