Skip to content

Commit a8fbf5b

Browse files
nhuetducoffeM
authored andcommitted
Revert "Replace direct call to get_upper_box by BoxDomain().get_upper"
This reverts commit e98b3c9. + add relevant import + remove unused imports
1 parent 71d235a commit a8fbf5b

File tree

1 file changed

+7
-10
lines changed

1 file changed

+7
-10
lines changed

src/decomon/backward_layers/utils.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@
88
from decomon.core import (
99
BoxDomain,
1010
ForwardMode,
11-
GridDomain,
1211
InputsOutputsSpec,
1312
PerturbationDomain,
14-
Slope,
1513
get_affine,
1614
get_ibp,
15+
get_lower_box,
16+
get_upper_box,
1717
)
1818
from decomon.layers.utils import sort
19-
from decomon.utils import get_linear_hull_relu, maximum, minus, relu_, subtract
19+
from decomon.utils import maximum, minus, relu_, subtract
2020

2121

2222
def backward_add(
@@ -48,7 +48,6 @@ def backward_add(
4848
if perturbation_domain is None:
4949
perturbation_domain = BoxDomain()
5050
mode = ForwardMode(mode)
51-
affine = get_affine(mode)
5251
op_flat = Flatten(dtype=K.floatx()) # pas terrible a revoir
5352
inputs_outputs_spec = InputsOutputsSpec(dc_decomp=dc_decomp, mode=mode, perturbation_domain=perturbation_domain)
5453
x_0, u_c_0, w_u_0, b_u_0, l_c_0, w_l_0, b_l_0, h_0, g_0 = inputs_outputs_spec.get_fullinputs_from_inputsformode(
@@ -63,12 +62,10 @@ def backward_add(
6362
l_c_0 = op_flat(l_c_0)
6463
l_c_1 = op_flat(l_c_1)
6564

66-
x_0 = K.concatenate([K.expand_dims(l_c_0, 1), K.expand_dims(u_c_0, 1)], 1)
67-
x_1 = K.concatenate([K.expand_dims(l_c_1, 1), K.expand_dims(u_c_1, 1)], 1)
68-
upper_0 = BoxDomain().get_upper(x_0, w_u_out, b_u_out)
69-
upper_1 = BoxDomain().get_upper(x_1, w_u_out, b_u_out)
70-
lower_0 = BoxDomain().get_lower(x_0, w_l_out, b_l_out)
71-
lower_1 = BoxDomain().get_lower(x_1, w_l_out, b_l_out)
65+
upper_0 = get_upper_box(l_c_0, u_c_0, w_u_out, b_u_out)
66+
upper_1 = get_upper_box(l_c_1, u_c_1, w_u_out, b_u_out)
67+
lower_0 = get_lower_box(l_c_0, u_c_0, w_l_out, b_l_out)
68+
lower_1 = get_lower_box(l_c_1, u_c_1, w_l_out, b_l_out)
7269

7370
w_u_out_0 = w_u_out
7471
b_u_out_0 = upper_1

0 commit comments

Comments
 (0)