Skip to content

Commit d1be796

Browse files
committed
Remove useless checks guaranteed by tracks
1 parent 4c40efa commit d1be796

File tree

1 file changed

+70
-74
lines changed

1 file changed

+70
-74
lines changed

pytensor/tensor/rewriting/blas.py

Lines changed: 70 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -636,93 +636,89 @@ def local_inplace_ger(fgraph, node):
636636
@node_rewriter([gemm_no_inplace])
637637
def local_gemm_to_gemv(fgraph, node):
638638
"""GEMM acting on row or column matrices -> GEMV."""
639-
if node.op == gemm_no_inplace:
640-
z, a, x, y, b = node.inputs
641-
if z.broadcastable == x.broadcastable == (True, False):
642-
r = gemv_no_inplace(z.dimshuffle(1), a, y.T, x.dimshuffle(1), b)
643-
new_out = [r.dimshuffle("x", 0)]
644-
elif z.broadcastable == y.broadcastable == (False, True):
645-
r = gemv_no_inplace(z.dimshuffle(0), a, x, y.dimshuffle(0), b)
646-
new_out = [r.dimshuffle(0, "x")]
647-
else:
648-
return
649-
copy_stack_trace(node.outputs, new_out)
650-
return new_out
639+
z, a, x, y, b = node.inputs
640+
if z.broadcastable == x.broadcastable == (True, False):
641+
r = gemv_no_inplace(z.dimshuffle(1), a, y.T, x.dimshuffle(1), b)
642+
new_out = [r.dimshuffle("x", 0)]
643+
elif z.broadcastable == y.broadcastable == (False, True):
644+
r = gemv_no_inplace(z.dimshuffle(0), a, x, y.dimshuffle(0), b)
645+
new_out = [r.dimshuffle(0, "x")]
646+
else:
647+
return
648+
copy_stack_trace(node.outputs, new_out)
649+
return new_out
651650

652651

653652
@node_rewriter([gemm_no_inplace])
654653
def local_gemm_to_ger(fgraph, node):
655654
"""GEMM computing an outer-product -> GER."""
656-
if node.op == gemm_no_inplace:
657-
z, a, x, y, b = node.inputs
658-
if x.broadcastable[1] and y.broadcastable[0]:
659-
# x and y are both vectors so this might qualify for a GER
660-
xv = x.dimshuffle(0)
661-
yv = y.dimshuffle(1)
662-
try:
663-
bval = ptb.get_underlying_scalar_constant_value(b)
664-
except NotScalarConstantError:
665-
# b isn't a constant, GEMM is doing useful pre-scaling
666-
return
667-
668-
if bval == 1: # best case a natural GER
669-
rval = ger(z, a, xv, yv)
670-
new_out = [rval]
671-
elif bval == 0: # GER on zeros_like should be faster than GEMM
672-
zeros = ptb.zeros([x.shape[0], y.shape[1]], x.dtype)
673-
rval = ger(zeros, a, xv, yv)
674-
new_out = [rval]
675-
else:
676-
# if bval is another constant, then z is being usefully
677-
# pre-scaled and GER isn't really the right tool for the job.
678-
return
679-
copy_stack_trace(node.outputs, new_out)
680-
return new_out
681-
655+
z, a, x, y, b = node.inputs
656+
if x.broadcastable[1] and y.broadcastable[0]:
657+
# x and y are both vectors so this might qualify for a GER
658+
xv = x.dimshuffle(0)
659+
yv = y.dimshuffle(1)
660+
try:
661+
bval = ptb.get_underlying_scalar_constant_value(b)
662+
except NotScalarConstantError:
663+
# b isn't a constant, GEMM is doing useful pre-scaling
664+
return
682665

683-
# TODO: delete this optimization when we have the proper dot->gemm->ger pipeline
684-
# working
685-
@node_rewriter([_dot22])
686-
def local_dot22_to_ger_or_gemv(fgraph, node):
687-
"""dot22 computing an outer-product -> GER, or a vector/matrix product -> GEMV."""
688-
if node.op == _dot22:
689-
x, y = node.inputs
690-
xb = x.broadcastable
691-
yb = y.broadcastable
692-
one = ptb.as_tensor_variable(np.asarray(1, dtype=x.dtype))
693-
zero = ptb.as_tensor_variable(np.asarray(0, dtype=x.dtype))
694-
if xb[1] and yb[0]:
695-
# x and y are both vectors so this might qualify for a GER
696-
xv = x.dimshuffle(0)
697-
yv = y.dimshuffle(1)
698-
zeros = ptb.zeros([x.shape[0], y.shape[1]], dtype=x.dtype)
699-
rval = ger(zeros, one, xv, yv)
666+
if bval == 1: # best case a natural GER
667+
rval = ger(z, a, xv, yv)
668+
new_out = [rval]
669+
elif bval == 0: # GER on zeros_like should be faster than GEMM
670+
zeros = ptb.zeros([x.shape[0], y.shape[1]], x.dtype)
671+
rval = ger(zeros, a, xv, yv)
700672
new_out = [rval]
701-
elif xb[0] and yb[1]:
702-
# x and y are both vectors so this qualifies for a sdot / ddot
703-
# PyTensor's CGemv will call sdot/ddot at runtime, the Scipy Gemv may not
704-
xv = x.dimshuffle(1)
705-
zeros = ptb.AllocEmpty(x.dtype)(1)
706-
rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
707-
new_out = [rval.dimshuffle("x", 0)]
708-
elif xb[0] and not yb[0] and not yb[1]:
709-
# x is vector, y is matrix so try gemv
710-
xv = x.dimshuffle(1)
711-
zeros = ptb.AllocEmpty(x.dtype)(y.shape[1])
712-
rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
713-
new_out = [rval.dimshuffle("x", 0)]
714-
elif not xb[0] and not xb[1] and yb[1]:
715-
# x is matrix, y is vector, try gemv
716-
yv = y.dimshuffle(0)
717-
zeros = ptb.AllocEmpty(x.dtype)(x.shape[0])
718-
rval = gemv_no_inplace(zeros, one, x, yv, zero)
719-
new_out = [rval.dimshuffle(0, "x")]
720673
else:
674+
# if bval is another constant, then z is being usefully
675+
# pre-scaled and GER isn't really the right tool for the job.
721676
return
722677
copy_stack_trace(node.outputs, new_out)
723678
return new_out
724679

725680

681+
# TODO: delete this optimization when we have the proper dot->gemm->ger pipeline working
682+
@node_rewriter([_dot22])
683+
def local_dot22_to_ger_or_gemv(fgraph, node):
684+
"""dot22 computing an outer-product -> GER, or a vector/matrix product -> GEMV."""
685+
x, y = node.inputs
686+
xb = x.broadcastable
687+
yb = y.broadcastable
688+
one = ptb.as_tensor_variable(np.asarray(1, dtype=x.dtype))
689+
zero = ptb.as_tensor_variable(np.asarray(0, dtype=x.dtype))
690+
if xb[1] and yb[0]:
691+
# x and y are both vectors so this might qualify for a GER
692+
xv = x.dimshuffle(0)
693+
yv = y.dimshuffle(1)
694+
zeros = ptb.zeros([x.shape[0], y.shape[1]], dtype=x.dtype)
695+
rval = ger(zeros, one, xv, yv)
696+
new_out = [rval]
697+
elif xb[0] and yb[1]:
698+
# x and y are both vectors so this qualifies for a sdot / ddot
699+
# PyTensor's CGemv will call sdot/ddot at runtime, the Scipy Gemv may not
700+
xv = x.dimshuffle(1)
701+
zeros = ptb.AllocEmpty(x.dtype)(1)
702+
rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
703+
new_out = [rval.dimshuffle("x", 0)]
704+
elif xb[0] and not yb[0] and not yb[1]:
705+
# x is vector, y is matrix so try gemv
706+
xv = x.dimshuffle(1)
707+
zeros = ptb.AllocEmpty(x.dtype)(y.shape[1])
708+
rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
709+
new_out = [rval.dimshuffle("x", 0)]
710+
elif not xb[0] and not xb[1] and yb[1]:
711+
# x is matrix, y is vector, try gemv
712+
yv = y.dimshuffle(0)
713+
zeros = ptb.AllocEmpty(x.dtype)(x.shape[0])
714+
rval = gemv_no_inplace(zeros, one, x, yv, zero)
715+
new_out = [rval.dimshuffle(0, "x")]
716+
else:
717+
return
718+
copy_stack_trace(node.outputs, new_out)
719+
return new_out
720+
721+
726722
#################################
727723
#
728724
# Set up the BlasOpt optimizer

0 commit comments

Comments
 (0)