@@ -636,93 +636,89 @@ def local_inplace_ger(fgraph, node):
636
636
@node_rewriter ([gemm_no_inplace ])
637
637
def local_gemm_to_gemv (fgraph , node ):
638
638
"""GEMM acting on row or column matrices -> GEMV."""
639
- if node .op == gemm_no_inplace :
640
- z , a , x , y , b = node .inputs
641
- if z .broadcastable == x .broadcastable == (True , False ):
642
- r = gemv_no_inplace (z .dimshuffle (1 ), a , y .T , x .dimshuffle (1 ), b )
643
- new_out = [r .dimshuffle ("x" , 0 )]
644
- elif z .broadcastable == y .broadcastable == (False , True ):
645
- r = gemv_no_inplace (z .dimshuffle (0 ), a , x , y .dimshuffle (0 ), b )
646
- new_out = [r .dimshuffle (0 , "x" )]
647
- else :
648
- return
649
- copy_stack_trace (node .outputs , new_out )
650
- return new_out
639
+ z , a , x , y , b = node .inputs
640
+ if z .broadcastable == x .broadcastable == (True , False ):
641
+ r = gemv_no_inplace (z .dimshuffle (1 ), a , y .T , x .dimshuffle (1 ), b )
642
+ new_out = [r .dimshuffle ("x" , 0 )]
643
+ elif z .broadcastable == y .broadcastable == (False , True ):
644
+ r = gemv_no_inplace (z .dimshuffle (0 ), a , x , y .dimshuffle (0 ), b )
645
+ new_out = [r .dimshuffle (0 , "x" )]
646
+ else :
647
+ return
648
+ copy_stack_trace (node .outputs , new_out )
649
+ return new_out
651
650
652
651
653
652
@node_rewriter ([gemm_no_inplace ])
654
653
def local_gemm_to_ger (fgraph , node ):
655
654
"""GEMM computing an outer-product -> GER."""
656
- if node .op == gemm_no_inplace :
657
- z , a , x , y , b = node .inputs
658
- if x .broadcastable [1 ] and y .broadcastable [0 ]:
659
- # x and y are both vectors so this might qualifies for a GER
660
- xv = x .dimshuffle (0 )
661
- yv = y .dimshuffle (1 )
662
- try :
663
- bval = ptb .get_underlying_scalar_constant_value (b )
664
- except NotScalarConstantError :
665
- # b isn't a constant, GEMM is doing useful pre-scaling
666
- return
667
-
668
- if bval == 1 : # best case a natural GER
669
- rval = ger (z , a , xv , yv )
670
- new_out = [rval ]
671
- elif bval == 0 : # GER on zeros_like should be faster than GEMM
672
- zeros = ptb .zeros ([x .shape [0 ], y .shape [1 ]], x .dtype )
673
- rval = ger (zeros , a , xv , yv )
674
- new_out = [rval ]
675
- else :
676
- # if bval is another constant, then z is being usefully
677
- # pre-scaled and GER isn't really the right tool for the job.
678
- return
679
- copy_stack_trace (node .outputs , new_out )
680
- return new_out
681
-
655
+ z , a , x , y , b = node .inputs
656
+ if x .broadcastable [1 ] and y .broadcastable [0 ]:
657
+ # x and y are both vectors so this might qualifies for a GER
658
+ xv = x .dimshuffle (0 )
659
+ yv = y .dimshuffle (1 )
660
+ try :
661
+ bval = ptb .get_underlying_scalar_constant_value (b )
662
+ except NotScalarConstantError :
663
+ # b isn't a constant, GEMM is doing useful pre-scaling
664
+ return
682
665
683
- # TODO: delete this optimization when we have the proper dot->gemm->ger pipeline
684
- # working
685
- @node_rewriter ([_dot22 ])
686
- def local_dot22_to_ger_or_gemv (fgraph , node ):
687
- """dot22 computing an outer-product -> GER."""
688
- if node .op == _dot22 :
689
- x , y = node .inputs
690
- xb = x .broadcastable
691
- yb = y .broadcastable
692
- one = ptb .as_tensor_variable (np .asarray (1 , dtype = x .dtype ))
693
- zero = ptb .as_tensor_variable (np .asarray (0 , dtype = x .dtype ))
694
- if xb [1 ] and yb [0 ]:
695
- # x and y are both vectors so this might qualifies for a GER
696
- xv = x .dimshuffle (0 )
697
- yv = y .dimshuffle (1 )
698
- zeros = ptb .zeros ([x .shape [0 ], y .shape [1 ]], dtype = x .dtype )
699
- rval = ger (zeros , one , xv , yv )
666
+ if bval == 1 : # best case a natural GER
667
+ rval = ger (z , a , xv , yv )
668
+ new_out = [rval ]
669
+ elif bval == 0 : # GER on zeros_like should be faster than GEMM
670
+ zeros = ptb .zeros ([x .shape [0 ], y .shape [1 ]], x .dtype )
671
+ rval = ger (zeros , a , xv , yv )
700
672
new_out = [rval ]
701
- elif xb [0 ] and yb [1 ]:
702
- # x and y are both vectors so this qualifies for a sdot / ddot
703
- # PyTensor's CGemv will call sdot/ddot at runtime, the Scipy Gemv may not
704
- xv = x .dimshuffle (1 )
705
- zeros = ptb .AllocEmpty (x .dtype )(1 )
706
- rval = gemv_no_inplace (zeros , one , y .T , xv , zero )
707
- new_out = [rval .dimshuffle ("x" , 0 )]
708
- elif xb [0 ] and not yb [0 ] and not yb [1 ]:
709
- # x is vector, y is matrix so try gemv
710
- xv = x .dimshuffle (1 )
711
- zeros = ptb .AllocEmpty (x .dtype )(y .shape [1 ])
712
- rval = gemv_no_inplace (zeros , one , y .T , xv , zero )
713
- new_out = [rval .dimshuffle ("x" , 0 )]
714
- elif not xb [0 ] and not xb [1 ] and yb [1 ]:
715
- # x is matrix, y is vector, try gemv
716
- yv = y .dimshuffle (0 )
717
- zeros = ptb .AllocEmpty (x .dtype )(x .shape [0 ])
718
- rval = gemv_no_inplace (zeros , one , x , yv , zero )
719
- new_out = [rval .dimshuffle (0 , "x" )]
720
673
else :
674
+ # if bval is another constant, then z is being usefully
675
+ # pre-scaled and GER isn't really the right tool for the job.
721
676
return
722
677
copy_stack_trace (node .outputs , new_out )
723
678
return new_out
724
679
725
680
681
+ # TODO: delete this optimization when we have the proper dot->gemm->ger pipeline working
682
+ @node_rewriter ([_dot22 ])
683
+ def local_dot22_to_ger_or_gemv (fgraph , node ):
684
+ """dot22 computing an outer-product -> GER."""
685
+ x , y = node .inputs
686
+ xb = x .broadcastable
687
+ yb = y .broadcastable
688
+ one = ptb .as_tensor_variable (np .asarray (1 , dtype = x .dtype ))
689
+ zero = ptb .as_tensor_variable (np .asarray (0 , dtype = x .dtype ))
690
+ if xb [1 ] and yb [0 ]:
691
+ # x and y are both vectors so this might qualifies for a GER
692
+ xv = x .dimshuffle (0 )
693
+ yv = y .dimshuffle (1 )
694
+ zeros = ptb .zeros ([x .shape [0 ], y .shape [1 ]], dtype = x .dtype )
695
+ rval = ger (zeros , one , xv , yv )
696
+ new_out = [rval ]
697
+ elif xb [0 ] and yb [1 ]:
698
+ # x and y are both vectors so this qualifies for a sdot / ddot
699
+ # PyTensor's CGemv will call sdot/ddot at runtime, the Scipy Gemv may not
700
+ xv = x .dimshuffle (1 )
701
+ zeros = ptb .AllocEmpty (x .dtype )(1 )
702
+ rval = gemv_no_inplace (zeros , one , y .T , xv , zero )
703
+ new_out = [rval .dimshuffle ("x" , 0 )]
704
+ elif xb [0 ] and not yb [0 ] and not yb [1 ]:
705
+ # x is vector, y is matrix so try gemv
706
+ xv = x .dimshuffle (1 )
707
+ zeros = ptb .AllocEmpty (x .dtype )(y .shape [1 ])
708
+ rval = gemv_no_inplace (zeros , one , y .T , xv , zero )
709
+ new_out = [rval .dimshuffle ("x" , 0 )]
710
+ elif not xb [0 ] and not xb [1 ] and yb [1 ]:
711
+ # x is matrix, y is vector, try gemv
712
+ yv = y .dimshuffle (0 )
713
+ zeros = ptb .AllocEmpty (x .dtype )(x .shape [0 ])
714
+ rval = gemv_no_inplace (zeros , one , x , yv , zero )
715
+ new_out = [rval .dimshuffle (0 , "x" )]
716
+ else :
717
+ return
718
+ copy_stack_trace (node .outputs , new_out )
719
+ return new_out
720
+
721
+
726
722
#################################
727
723
#
728
724
# Set up the BlasOpt optimizer
0 commit comments