|
26 | 26 | XOR, |
27 | 27 | Add, |
28 | 28 | IntDiv, |
| 29 | + Maximum, |
| 30 | + Minimum, |
29 | 31 | Mul, |
30 | | - ScalarMaximum, |
31 | | - ScalarMinimum, |
32 | 32 | Sub, |
33 | 33 | TrueDiv, |
34 | 34 | get_scalar_type, |
35 | | - scalar_maximum, |
| 35 | + maximum, |
36 | 36 | ) |
37 | 37 | from pytensor.scalar.basic import add as add_as |
38 | 38 | from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise |
@@ -103,16 +103,16 @@ def scalar_in_place_fn_IntDiv(op, idx, res, arr): |
103 | 103 | return f"{res}[{idx}] //= {arr}" |
104 | 104 |
|
105 | 105 |
|
106 | | -@scalar_in_place_fn.register(ScalarMaximum) |
107 | | -def scalar_in_place_fn_ScalarMaximum(op, idx, res, arr): |
| 106 | +@scalar_in_place_fn.register(Maximum) |
| 107 | +def scalar_in_place_fn_Maximum(op, idx, res, arr): |
108 | 108 | return f""" |
109 | 109 | if {res}[{idx}] < {arr}: |
110 | 110 | {res}[{idx}] = {arr} |
111 | 111 | """ |
112 | 112 |
|
113 | 113 |
|
114 | | -@scalar_in_place_fn.register(ScalarMinimum) |
115 | | -def scalar_in_place_fn_ScalarMinimum(op, idx, res, arr): |
| 114 | +@scalar_in_place_fn.register(Minimum) |
| 115 | +def scalar_in_place_fn_Minimum(op, idx, res, arr): |
116 | 116 | return f""" |
117 | 117 | if {res}[{idx}] > {arr}: |
118 | 118 | {res}[{idx}] = {arr} |
@@ -458,7 +458,7 @@ def numba_funcify_Softmax(op, node, **kwargs): |
458 | 458 | if axis is not None: |
459 | 459 | axis = normalize_axis_index(axis, x_at.ndim) |
460 | 460 | reduce_max_py = create_multiaxis_reducer( |
461 | | - scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True |
| 461 | + maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True |
462 | 462 | ) |
463 | 463 | reduce_sum_py = create_multiaxis_reducer( |
464 | 464 | add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True |
@@ -522,7 +522,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs): |
522 | 522 | if axis is not None: |
523 | 523 | axis = normalize_axis_index(axis, x_at.ndim) |
524 | 524 | reduce_max_py = create_multiaxis_reducer( |
525 | | - scalar_maximum, |
| 525 | + maximum, |
526 | 526 | -np.inf, |
527 | 527 | (axis,), |
528 | 528 | x_at.ndim, |
|
0 commit comments