You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
steps: 150_001# If set to -1 then will inherit value from learning_rate_schedule_steps
615
-
log_period: 100#Flushes Tensorboard
615
+
log_period: 100#The frequency of Tensorboard flush, gcs metrics writing, and managed profiler metrics updating.
616
616
617
617
jax_distributed_initialization_timeout: 300# This is the default timeout in https://github.com/jax-ml/jax/blob/main/jax/_src/distributed.py
618
618
# Note there are two separate initializations - the jax coordination service (aka jax.distributed.initialize) and the backend (e.g. PjRT), the timeout above refers
@@ -658,6 +658,12 @@ profile_cleanly: True # If set to true, adds a block_until_ready on train state
658
658
profile_periodically_period: -1# If set to a positive integer, profile every profile_periodically_period steps.
659
659
# This is useful to debug scenarios where performance is changing.
660
660
661
+
# Managed ML diagnostics settings. If the feature is enabled, it will
662
+
# - create a managed ML diagnostics run with all the MaxText configs
663
+
# - upload xplane profiling, if it is enabled.
664
+
# - upload training metrics, at the defined log_period interval.
665
+
managed_mldiagnostics: False # Whether to enable the managed diagnostics
666
+
managed_mldiagnostics_run_group: ""# Optional. Used to group multiple runs.
0 commit comments