 from textwrap import indent
 
-import numpy as np
-
 from pytensor.gradient import DisconnectedType
-from pytensor.graph.basic import Apply, Variable
+from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.graph.replace import _vectorize_node
 from pytensor.link.c.op import COp
 from pytensor.link.c.params_type import ParamsType
 from pytensor.link.c.type import Generic
-from pytensor.scalar.basic import ScalarType
+from pytensor.scalar.basic import ScalarType, as_scalar
 from pytensor.tensor.type import DenseTensorType
 
 
@@ -56,18 +54,6 @@ def __str__(self):
         msg = self.msg
         return f"{name}{{raises={exc_name}, msg='{msg}'}}"
 
-    def __eq__(self, other):
-        if type(self) is not type(other):
-            return False
-
-        if self.msg == other.msg and self.exc_type == other.exc_type:
-            return True
-
-        return False
-
-    def __hash__(self):
-        return hash((self.msg, self.exc_type))
-
     def make_node(self, value: Variable, *conds: Variable):
         """
 
@@ -84,12 +70,10 @@ def make_node(self, value: Variable, *conds: Variable):
         if not isinstance(value, Variable):
             value = pt.as_tensor_variable(value)
 
-        conds = [
-            pt.as_tensor_variable(c) if not isinstance(c, Variable) else c
-            for c in conds
-        ]
-
-        assert all(c.type.ndim == 0 for c in conds)
+        conds = [as_scalar(c) for c in conds]
+        for i, cond in enumerate(conds):
+            if cond.dtype != "bool":
+                conds[i] = cond.astype("bool")
 
         return Apply(
             self,
@@ -101,7 +85,7 @@ def perform(self, node, inputs, outputs):
         (out,) = outputs
         val, *conds = inputs
         out[0] = val
-        if not np.all(conds):
+        if not all(conds):
             raise self.exc_type(self.msg)
 
     def grad(self, input, output_gradients):
@@ -117,38 +101,20 @@ def c_code(self, node, name, inames, onames, props):
         )
         value_name, *cond_names = inames
         out_name = onames[0]
-        check = []
         fail_code = props["fail"]
         param_struct_name = props["params"]
         msg = self.msg.replace('"', '\\"').replace("\n", "\\n")
 
-        for idx, cond_name in enumerate(cond_names):
-            if isinstance(node.inputs[0].type, DenseTensorType):
-                check.append(
-                    f"""
-        if(PyObject_IsTrue((PyObject *){cond_name}) == 0) {{
-            PyObject * exc_type = {param_struct_name}->exc_type;
-            Py_INCREF(exc_type);
-            PyErr_SetString(exc_type, "{msg}");
-            Py_XDECREF(exc_type);
-            {indent(fail_code, " " * 4)}
-        }}
-        """
-                )
-            else:
-                check.append(
-                    f"""
-        if({cond_name} == 0) {{
-            PyObject * exc_type = {param_struct_name}->exc_type;
-            Py_INCREF(exc_type);
-            PyErr_SetString(exc_type, "{msg}");
-            Py_XDECREF(exc_type);
-            {indent(fail_code, " " * 4)}
-        }}
-        """
-                )
-
-        check = "\n".join(check)
+        all_conds = " && ".join(cond_names)
+        check = f"""
+        if(!({all_conds})) {{
+            PyObject * exc_type = {param_struct_name}->exc_type;
+            Py_INCREF(exc_type);
+            PyErr_SetString(exc_type, "{msg}");
+            Py_XDECREF(exc_type);
+            {indent(fail_code, " " * 4)}
+        }}
+        """
 
         if isinstance(node.inputs[0].type, DenseTensorType):
             res = f"""
@@ -162,14 +128,19 @@ def c_code(self, node, name, inames, onames, props):
         {check}
         {out_name} = {value_name};
         """
-        return res
+
+        return "\n".join((check, res))
 
     def c_code_cache_version(self):
-        return (1, 1)
+        return (2,)
 
     def infer_shape(self, fgraph, node, input_shapes):
         return [input_shapes[0]]
 
+    def do_constant_folding(self, fgraph, node):
+        # Only constant-fold if the Assert does not fail
+        return all((isinstance(c, Constant) and bool(c.data)) for c in node.inputs[1:])
+
 
 class Assert(CheckAndRaise):
     """Implements assertion in a computational graph.
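For reference, a minimal usage sketch of the op touched by this diff (not part of the commit; it assumes `Assert` is importable from `pytensor.raise_op`, matching the class defined above): conditions are now coerced to boolean scalars in `make_node`, the generated C code collapses them into a single check, and `do_constant_folding` only permits folding when every condition is a constant that evaluates to True.

```python
# Hypothetical usage sketch, not part of the commit.
import pytensor
import pytensor.tensor as pt
from pytensor.raise_op import Assert  # assumed import path

x = pt.vector("x")

# A 0-d condition; make_node now converts it with as_scalar() and casts
# non-bool conditions to "bool".
assert_positive = Assert("x must be strictly positive")
y = assert_positive(x, pt.all(pt.gt(x, 0)))

f = pytensor.function([x], y)
print(f([1.0, 2.0]))  # passes the value through unchanged
# f([-1.0, 2.0])      # raises AssertionError: x must be strictly positive
```

If a condition is a constant that evaluates to True, the new `do_constant_folding` hook allows the rewriter to fold the `Assert` away; a constant-False condition is deliberately not folded, so the error is still raised at runtime.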