@@ -148,38 +148,18 @@ async def call(self, inputs, training=False):
148148 )
149149 choice = decision .get ("choice" )
150150 outputs = []
151+ if self .inject_decision :
152+ inputs = await ops .concat (
153+ inputs ,
154+ decision ,
155+ name = self .name + "_inputs_with_decision" ,
156+ )
151157 for label , module in self .branches .items ():
152158 if label == choice :
153159 if module :
154- if self .inject_decision and self . return_decision :
160+ if self .return_decision :
155161 outputs .append (
156- await ops .concat (
157- decision ,
158- await module (
159- await ops .concat (
160- inputs ,
161- decision ,
162- name = self .name + "_inputs_with_decision" ,
163- ),
164- training = training ,
165- ),
166- name = self .name + "_with_decision" ,
167- )
168- )
169- elif self .inject_decision and not self .return_decision :
170- outputs .append (
171- await module (
172- await ops .concat (
173- inputs ,
174- decision ,
175- name = self .name + "_inputs_with_decision" ,
176- ),
177- training = training ,
178- )
179- )
180- elif not self .inject_decision and self .return_decision :
181- outputs .append (
182- await ops .concat (
162+ await ops .logical_and (
183163 decision ,
184164 await module (
185165 inputs ,
@@ -207,36 +187,16 @@ async def compute_output_spec(self, inputs, training=False):
207187 inputs ,
208188 training = training ,
209189 )
190+ if self .inject_decision :
191+ inputs = await ops .concat (
192+ inputs ,
193+ decision ,
194+ name = self .name + "_inputs_with_decision" ,
195+ )
210196 for module in self .branches .values ():
211- if self .inject_decision and self .return_decision :
212- outputs .append (
213- await ops .concat (
214- decision ,
215- await module (
216- await ops .concat (
217- inputs ,
218- decision ,
219- name = self .name + "_inputs_with_decision" ,
220- ),
221- training = training ,
222- ),
223- name = self .name + "_with_decision" ,
224- )
225- )
226- elif self .inject_decision and not self .return_decision :
227- outputs .append (
228- await module (
229- await ops .concat (
230- inputs ,
231- decision ,
232- name = self .name + "_inputs_with_decision" ,
233- ),
234- training = training ,
235- )
236- )
237- elif not self .inject_decision and self .return_decision :
197+ if self .return_decision :
238198 outputs .append (
239- await ops .concat (
199+ await ops .logical_and (
240200 decision ,
241201 await module (
242202 inputs ,
0 commit comments