@@ -51,6 +51,10 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
5151 else :
5252 input_shape = conditions_shape
5353
54+ # Save input_shape and xz_shape for usage in get_build_config
55+ self ._input_shape = input_shape
56+ self ._xz_shape = xz_shape
57+
5458 # build the shared body network
5559 self .subnet .build (input_shape )
5660 body_output_shape = self .subnet .compute_output_shape (input_shape )
@@ -82,6 +86,62 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
8286 flat_key = f"{ score_key } ___{ head_key } "
8387 self .heads_flat [flat_key ] = head
8488
89+ def get_build_config (self ):
90+ build_config = {
91+ "conditions_shape" : self ._input_shape ,
92+ "xz_shape" : self ._xz_shape ,
93+ }
94+
95+ # Save names of head networks
96+ heads = {}
97+ for score_key in self .heads .keys ():
98+ heads [score_key ] = {}
99+ for head_key , head in self .heads [score_key ].items ():
100+ heads [score_key ][head_key ] = head .name
101+ # Alternatively, save full build config of head
102+ # heads[score_key][head_key] = head.get_build_config()
103+ # TODO: decide
104+
105+ build_config ["heads" ] = heads
106+
107+ return build_config
108+
109+ def build_from_config (self , config ):
110+ self .build (xz_shape = config ["xz_shape" ], conditions_shape = config ["conditions_shape" ])
111+
112+ for score_key in self .scores .keys ():
113+ for head_key , head in self .heads [score_key ].items ():
114+ head .name = config ["heads" ][score_key ][head_key ]
115+
116+ # Alternatively, do NOT call self.build, but rather imitate it using the build config of each head
117+ # This results in some code duplication with self.build and requires heads to be of a custom type.
118+ # TODO: decide
119+
120+ # input_shape = config["conditions_shape"]
121+ #
122+ # # Save input_shape for usage in get_build_config
123+ # self._input_shape = input_shape
124+ #
125+ # # build the shared body network
126+ # self.subnet.build(input_shape)
127+ #
128+ # # build head(s) for every scoring rule
129+ # self.heads = dict()
130+ # self.heads_flat = dict()
131+ #
132+ # for score_key in self.scores.keys():
133+ #
134+ # self.heads[score_key] = {}
135+ #
136+ # for head_key, head_config in config["heads"][score_key].items():
137+ # head = keras.Sequential()
138+ # head.build_from_config(head_config) # TODO: this method is not implemented yet
139+ # it would require the head to be a
140+ # custom object rather than a Sequential
141+ # self.heads[score_key][head_key] = head
142+ # flat_key = f"{score_key}___{head_key}"
143+ # self.heads_flat[flat_key] = head
144+
85145 def get_config (self ):
86146 base_config = super ().get_config ()
87147
0 commit comments