84
84
from pytensor .tensor .math import sum as at_sum
85
85
from pytensor .tensor .math import true_div
86
86
from pytensor .tensor .rewriting .basic import (
87
- broadcast_like ,
87
+ alloc_like ,
88
88
broadcasted_by ,
89
89
local_fill_sink ,
90
90
register_canonicalize ,
@@ -1973,7 +1973,7 @@ def local_div_to_reciprocal(fgraph, node):
1973
1973
new_out = cast (new_out , dtype = out .dtype )
1974
1974
# The ones could have forced a specific length
1975
1975
if not out .type .is_super (new_out .type ):
1976
- new_out = broadcast_like (new_out , out , fgraph )
1976
+ new_out = alloc_like (new_out , out , fgraph )
1977
1977
return [new_out ]
1978
1978
else :
1979
1979
return False
@@ -1994,9 +1994,9 @@ def local_pow_canonicalize(fgraph, node):
1994
1994
if node .op == at_pow :
1995
1995
cst = get_constant (node .inputs [1 ])
1996
1996
if cst == 0 :
1997
- return [broadcast_like (1 , node .outputs [0 ], fgraph )]
1997
+ return [alloc_like (1 , node .outputs [0 ], fgraph )]
1998
1998
if cst == 1 :
1999
- return [broadcast_like (node .inputs [0 ], node .outputs [0 ], fgraph )]
1999
+ return [alloc_like (node .inputs [0 ], node .outputs [0 ], fgraph )]
2000
2000
else :
2001
2001
return False
2002
2002
@@ -2033,7 +2033,7 @@ def local_zero_div(fgraph, node):
2033
2033
node .op .scalar_op , (aes .IntDiv , aes .TrueDiv )
2034
2034
):
2035
2035
if get_constant (node .inputs [0 ]) == 0 :
2036
- ret = broadcast_like (0 , node .outputs [0 ], fgraph )
2036
+ ret = alloc_like (0 , node .outputs [0 ], fgraph )
2037
2037
ret .tag .values_eq_approx = values_eq_approx_remove_nan
2038
2038
return [ret ]
2039
2039
@@ -2184,7 +2184,7 @@ def local_mul_specialize(fgraph, node):
2184
2184
has_neg ^= True # toggles
2185
2185
elif y == 0.0 :
2186
2186
# if we find any zero, we just return right away
2187
- return [broadcast_like (0 , node .outputs [0 ], fgraph )]
2187
+ return [alloc_like (0 , node .outputs [0 ], fgraph )]
2188
2188
else :
2189
2189
new_inputs .append (inp )
2190
2190
@@ -2209,14 +2209,14 @@ def local_mul_specialize(fgraph, node):
2209
2209
new_inputs = [m1 ] + new_inputs
2210
2210
rval = mul (* new_inputs )
2211
2211
2212
- return [broadcast_like (rval , node .outputs [0 ], fgraph )]
2212
+ return [alloc_like (rval , node .outputs [0 ], fgraph )]
2213
2213
else :
2214
2214
# there are no variable inputs to mul
2215
2215
# N.B. this could have been constant-folded...
2216
2216
if has_neg :
2217
- return [broadcast_like (- 1 , node .outputs [0 ], fgraph )]
2217
+ return [alloc_like (- 1 , node .outputs [0 ], fgraph )]
2218
2218
else :
2219
- return [broadcast_like (1 , node .outputs [0 ], fgraph )]
2219
+ return [alloc_like (1 , node .outputs [0 ], fgraph )]
2220
2220
2221
2221
2222
2222
@register_specialize
0 commit comments