Skip to content

Commit 7dd401c

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Fix save_from_both_policies in presence of save_and_offload_only_these_names by comparing the enum
PiperOrigin-RevId: 706874882
1 parent 772339e commit 7dd401c

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

jax/_src/ad_checkpoint.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,14 @@ def policy(prim, *_, **params):
142142
def save_from_both_policies(policy_1, policy_2):
143143

144144
def policy(prim, *args, **params):
145-
return policy_1(prim, *args, **params) or policy_2(prim, *args, **params)
146-
145+
out1 = policy_1(prim, *args, **params)
146+
out2 = policy_2(prim, *args, **params)
147+
if not (isinstance(out1, bool) and isinstance(out2, bool)):
148+
raise ValueError(
149+
"The return value of the policies should be a boolean. Got:"
150+
f" {out1} and {out2}. Please write a custom policy function directly,"
151+
" rather than using this helper function.")
152+
return out1 or out2
147153
return policy
148154

149155

0 commit comments

Comments
 (0)