@@ -985,26 +985,30 @@ def check_log_abs_det(fgraph, client):
985985
986986@node_rewriter (tracks = [det ])
987987def slogdet_specialization (fgraph , node ):
988- x = node .inputs [0 ]
989- sign_det_x , slog_det_x = SLogDet ()(x )
990988 replacements = {}
991989 for client in fgraph .clients [node .outputs [0 ]]:
992990 # Check for sign(det)
993991 if isinstance (client [0 ].op , Elemwise ) and isinstance (
994992 client [0 ].op .scalar_op , Sign
995993 ):
996- replacements [client [0 ].owner .outputs [0 ]] = sign_det_x
994+ x = node .inputs [0 ]
995+ sign_det_x , slog_det_x = SLogDet ()(x )
996+ replacements [client [0 ].outputs [0 ]] = sign_det_x
997997
998998 # Check for log(abs(det))
999999 elif check_log_abs_det (fgraph , client [0 ]):
1000- replacements [client [0 ].owner .outputs [0 ]] = slog_det_x
1000+ x = node .inputs [0 ]
1001+ sign_det_x , slog_det_x = SLogDet ()(x )
1002+ replacements [fgraph .clients [client [0 ].outputs [0 ]][0 ][0 ].outputs [0 ]] = (
1003+ slog_det_x
1004+ )
10011005
10021006 # Check for log(det)
1003- elif isinstance (client [0 ].op , Elemwise ) and isinstance (
1004- client [0 ].op .scalar_op , Log
1005- ):
1006- pass
1007- # replacements[client[0].owner.outputs[0]] = pt.where(pt.eq(sign_det_x, -1), np.nan, slog_det_x)
1007+ # elif isinstance(client[0].op, Elemwise) and isinstance(
1008+ # client[0].op.scalar_op, Log
1009+ # ):
1010+ # pass
1011+ # replacements[client[0].owner.outputs[0]] = pt.where(pt.eq(sign_det_x, -1), np.nan, slog_det_x)
10081012
10091013 # Det is used directly for something else, don't rewrite to avoid computing two dets
10101014 else :
0 commit comments