|
69 | 69 | log,
|
70 | 70 | log1mexp,
|
71 | 71 | log1p,
|
| 72 | + log1pexp, |
72 | 73 | lt,
|
73 | 74 | maximum,
|
74 | 75 | minimum,
|
@@ -1968,27 +1969,53 @@ def test_exp_softplus(self, exp_op):
|
1968 | 1969 | decimal=6,
|
1969 | 1970 | )
|
1970 | 1971 |
|
1971 |
| - def test_softplus_log(self): |
1972 |
| - # softplus(log(x)) -> log1p(x) |
| 1972 | + def test_log1pexp_log(self): |
| 1973 | + # log1pexp(log(x)) -> log1p(x) |
1973 | 1974 | data_valid = np.random.random((4, 3)).astype("float32") * 2
|
1974 | 1975 | data_valid[0, 0] = 0 # edge case
|
1975 | 1976 | data_invalid = data_valid - 2
|
1976 | 1977 |
|
1977 | 1978 | x = fmatrix()
|
1978 |
| - f = function([x], softplus(log(x)), mode=self.mode) |
1979 |
| - graph = f.maker.fgraph.toposort() |
1980 |
| - ops_graph = [ |
1981 |
| - node |
1982 |
| - for node in graph |
1983 |
| - if isinstance(node.op, Elemwise) |
1984 |
| - and isinstance(node.op.scalar_op, ps.Log | ps.Exp | ps.Softplus) |
1985 |
| - ] |
1986 |
| - assert len(ops_graph) == 0 |
| 1979 | + f = function([x], log1pexp(log(x)), mode=self.mode.excluding("inplace")) |
| 1980 | + assert equal_computations( |
| 1981 | + f.maker.fgraph.outputs, |
| 1982 | + [ |
| 1983 | + pt.switch( |
| 1984 | + x >= np.array([[0]], dtype=np.int8), |
| 1985 | + pt.log1p(x), |
| 1986 | + np.array([[np.nan]], dtype=np.float32), |
| 1987 | + ) |
| 1988 | + ], |
| 1989 | + ) |
1987 | 1990 |
|
1988 | 1991 | expected = np.log1p(data_valid)
|
1989 | 1992 | np.testing.assert_almost_equal(f(data_valid), expected)
|
1990 | 1993 | assert np.all(np.isnan(f(data_invalid)))
|
1991 | 1994 |
|
| 1995 | + def test_log1mexp_log(self): |
| 1996 | + # log1mexp(log(x)) -> log1p(-x) |
| 1997 | + data_valid = np.random.random((4, 3)).astype("float32") |
| 1998 | + data_valid[0, 0] = 0 # edge case |
| 1999 | + data_valid[0, 1] = 1 # another edge case |
| 2000 | + data_invalid = np.concatenate([data_valid + 1.1, data_valid - 1.1]) |
| 2001 | + |
| 2002 | + x = fmatrix() |
| 2003 | + f = function([x], log1mexp(log(x)), mode=self.mode.excluding("inplace")) |
| 2004 | + assert equal_computations( |
| 2005 | + f.maker.fgraph.outputs, |
| 2006 | + [ |
| 2007 | + pt.switch( |
| 2008 | + x >= np.array([[0]], dtype=np.int8), |
| 2009 | + pt.log1p(-x), |
| 2010 | + np.array([[np.nan]], dtype=np.float32), |
| 2011 | + ) |
| 2012 | + ], |
| 2013 | + ) |
| 2014 | + |
| 2015 | + expected = np.log1p(-data_valid) |
| 2016 | + np.testing.assert_almost_equal(f(data_valid), expected) |
| 2017 | + assert np.all(np.isnan(f(data_invalid))) |
| 2018 | + |
1992 | 2019 | @pytest.mark.parametrize(
|
1993 | 2020 | ["nested_expression", "expected_switches"],
|
1994 | 2021 | [
|
|
0 commit comments