File tree Expand file tree Collapse file tree 2 files changed +48
-0
lines changed
pytensor/tensor/rewriting Expand file tree Collapse file tree 2 files changed +48
-0
lines changed Original file line number Diff line number Diff line change @@ -1020,3 +1020,29 @@ def slogdet_specialization(fgraph, node):
10201020 k : slogdet_specialization_map [v ] for k , v in dummy_replacements .items ()
10211021 }
10221022 return replacements
1023+
1024+
1025+ @register_stabilize
1026+ @register_canonicalize
1027+ @node_rewriter ([Blockwise ])
1028+ def scalar_solve_to_divison (fgraph , node ):
1029+ """
1030+ Replace solve(a, b) with b / a if a is a (1, 1) matrix
1031+ """
1032+
1033+ core_op = node .op .core_op
1034+ if not isinstance (core_op , Solve ):
1035+ return None
1036+
1037+ a , b = node .inputs
1038+ old_out = node .outputs [0 ]
1039+ if not all (a .broadcastable [- 2 :]):
1040+ return None
1041+
1042+ new_out = b / a
1043+ if core_op .b_ndim == 1 :
1044+ new_out = new_out .squeeze (- 1 )
1045+
1046+ copy_stack_trace (old_out , new_out )
1047+
1048+ return [new_out ]
Original file line number Diff line number Diff line change @@ -993,3 +993,25 @@ def test_slogdet_specialization():
993993 f = function ([x ], [exp_det_x , sign_det_x ], mode = "FAST_RUN" )
994994 nodes = f .maker .fgraph .apply_nodes
995995 assert not any (isinstance (node .op , SLogDet ) for node in nodes )
996+
997+
998+ def test_scalar_solve_to_division_rewrite ():
999+ rng = np .random .default_rng (sum (map (ord , "scalar_solve_to_division_rewrite" )))
1000+
1001+ a = pt .dmatrix ("a" , shape = (1 , 1 ))
1002+ b = pt .dvector ("b" )
1003+
1004+ c = pt .linalg .solve (a , b , b_ndim = 1 )
1005+
1006+ f = function ([a , b ], c , mode = "FAST_RUN" )
1007+ nodes = f .maker .fgraph .apply_nodes
1008+
1009+ assert not any (isinstance (node .op , Solve ) for node in nodes )
1010+
1011+ a_val = rng .normal (size = (1 , 1 )).astype (pytensor .config .floatX )
1012+ b_val = rng .normal (size = (1 ,)).astype (pytensor .config .floatX )
1013+
1014+ c_val = np .linalg .solve (a_val , b_val )
1015+ np .testing .assert_allclose (
1016+ f (a_val , b_val ), c_val , rtol = 1e-7 if config .floatX == "float64" else 1e-5
1017+ )
You can’t perform that action at this time.
0 commit comments