Skip to content

Commit 7c9efd4

Browse files
committed
drafting feature
1 parent 822ad89 commit 7c9efd4

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

bayesflow/networks/deep_set/equivariant_layer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,10 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
137137
output_set = ops.concatenate([input_set, invariant_summary], axis=-1)
138138

139139
# Pass through final equivariant transform + residual
140-
output_set = input_set + self.equivariant_fc(output_set, training=training)
140+
out_fc = self.equivariant_fc(output_set, training=training)
141+
out_projected = self.out_fc_projector(out_fc)
142+
output_set = input_set + out_projected
143+
# output_set = input_set + self.equivariant_fc(output_set, training=training)
141144
if self.layer_norm is not None:
142145
output_set = self.layer_norm(output_set, training=training)
143146

bayesflow/networks/deep_set/invariant_layer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,10 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
106106
"""
107107

108108
set_summary = self.inner_fc(input_set, training=training)
109+
set_summary = self.inner_projector(set_summary, training=training)
109110
set_summary = self.pooling_layer(set_summary, training=training)
110111
set_summary = self.outer_fc(set_summary, training=training)
112+
set_summary = self.outer_projector(set_summary, training=training)
111113
return set_summary
112114

113115
@sanitize_input_shape

0 commit comments

Comments
 (0)