@@ -855,35 +855,44 @@ def find_default_update(clients, rng: Variable) -> None | Variable:
855855
856856 # Root case, RNG is not used elsewhere
857857 if not rng_clients :
858- return rng
858+ return None
859859
860860 if len (rng_clients ) > 1 :
861861 # Multiple clients are techincally fine if they are used in identical operations
862862 # We check if the default_update of each client would be the same
863- update , * other_updates = (
863+ all_updates = [
864864 find_default_update (
865865 # Pass version of clients that includes only one the RNG clients at a time
866866 clients | {rng : [rng_client ]},
867867 rng ,
868868 )
869869 for rng_client in rng_clients
870- )
871- if all (equal_computations ([update ], [other_update ]) for other_update in other_updates ):
872- return update
873-
874- warnings .warn (
875- f"RNG Variable { rng } has multiple distinct clients { rng_clients } , "
876- f"likely due to an inconsistent random graph. "
877- f"No default update will be returned." ,
878- UserWarning ,
879- )
880- return None
870+ ]
871+ updates = [update for update in all_updates if update is not None ]
872+ if not updates :
873+ return None
874+ if len (updates ) == 1 :
875+ return updates [0 ]
876+ else :
877+ update , * other_updates = updates
878+ if all (
879+ equal_computations ([update ], [other_update ]) for other_update in other_updates
880+ ):
881+ return update
882+
883+ warnings .warn (
884+ f"RNG Variable { rng } has multiple distinct clients { rng_clients } , "
885+ f"likely due to an inconsistent random graph. "
886+ f"No default update will be returned." ,
887+ UserWarning ,
888+ )
889+ return None
881890
882891 [client , _ ] = rng_clients [0 ]
883892
884893 # RNG is an output of the function, this is not a problem
885894 if isinstance (client .op , Output ):
886- return rng
895+ return None
887896
888897 # RNG is used by another operator, which should output an update for the RNG
889898 if isinstance (client .op , RandomVariable ):
@@ -912,18 +921,26 @@ def find_default_update(clients, rng: Variable) -> None | Variable:
912921 )
913922 elif isinstance (client .op , OpFromGraph ):
914923 try :
915- next_rng = collect_default_updates_inner_fgraph (client )[rng ]
916- except (ValueError , KeyError ):
924+ next_rng = collect_default_updates_inner_fgraph (client ).get (rng )
925+ if next_rng is None :
926+ # OFG either does not make use of this RNG or inconsistent use that will have emitted a warning
927+ return None
928+ except ValueError as exc :
917929 raise ValueError (
918930 f"No update found for at least one RNG used in OpFromGraph Op { client .op } .\n "
919931 "You can use `pytensorf.collect_default_updates` and include those updates as outputs."
920- )
932+ ) from exc
921933 else :
922934 # We don't know how this RNG should be updated. The user should provide an update manually
923935 return None
924936
925937 # Recurse until we find final update for RNG
926- return find_default_update (clients , next_rng )
938+ nested_next_rng = find_default_update (clients , next_rng )
939+ if nested_next_rng is None :
940+ # There were no more uses of this next_rng
941+ return next_rng
942+ else :
943+ return nested_next_rng
927944
928945 if inputs is None :
929946 inputs = []
0 commit comments