diff --git a/Wrappers/Python/cil/optimisation/functions/OperatorCompositionFunction.py b/Wrappers/Python/cil/optimisation/functions/OperatorCompositionFunction.py index 7f8a3b2f1e..68be8b18cd 100644 --- a/Wrappers/Python/cil/optimisation/functions/OperatorCompositionFunction.py +++ b/Wrappers/Python/cil/optimisation/functions/OperatorCompositionFunction.py @@ -77,3 +77,30 @@ def gradient(self, x, out=None): self.function.gradient(tmp, out=tmp) return self.operator.adjoint(tmp, out=out) + def proximal(self, x, tau, out=None): + + if not self.operator.is_orthogonal(): + raise ValueError("Semi-orthogonality is required for operator.") + + tmp = self.operator.range_geometry().allocate() + self.operator.direct(x, out=tmp) + + if out is None: + return x + (1./self.operator.orthogonal_scalar)*self.operator.adjoint(self.function.proximal(tmp, tau=tau*self.operator.orthogonal_scalar) - tmp) + else: + self.operator.adjoint(self.function.proximal(tmp, tau=tau*self.operator.orthogonal_scalar) - tmp, out=out) + x.sapyb(1., out, 1./(self.operator.orthogonal_scalar), out=out) + + + + + + + + + + + + + + diff --git a/Wrappers/Python/cil/optimisation/operators/WaveletOperator.py b/Wrappers/Python/cil/optimisation/operators/WaveletOperator.py index 8e8264c58e..11d961e3cf 100644 --- a/Wrappers/Python/cil/optimisation/operators/WaveletOperator.py +++ b/Wrappers/Python/cil/optimisation/operators/WaveletOperator.py @@ -154,6 +154,9 @@ def __init__(self, domain_geometry, raise AttributeError( f"Size of the range geometry is {range_geometry.shape} but the size of the wavelet coefficient array must be {tuple(range_shape)}.") + # semi-orthogonality constant + self.orthogonal_scalar = kwargs.get('orthogonal_scalar', 1.0) + super().__init__(domain_geometry=domain_geometry, range_geometry=range_geometry) def _shape2slice(self):