You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
- include new SDK google-cloud-mldiagnostics
- add new config params
- add new file managed_mldiagnostics.py
- modify profiler.py to profile with mldiagnostics
- modify metrics_logger.py to upload metrics
steps: 150_001# If set to -1 then will inherit value from learning_rate_schedule_steps
613
-
log_period: 100#Flushes Tensorboard
613
+
log_period: 100#The frequency of Tensorboard flush, gcs metrics writing, and managed profiler metrics updating.
614
614
615
615
jax_distributed_initialization_timeout: 300# This is the default timeout in https://github.com/jax-ml/jax/blob/main/jax/_src/distributed.py
616
616
# Note there are two separate initializations - the jax coordination service (aka jax.distributed.initialize) and the backend (e.g. PjRT), the timeout above refers
@@ -656,6 +656,12 @@ profile_cleanly: True # If set to true, adds a block_until_ready on train state
656
656
profile_periodically_period: -1# If set to a positive integer, profile every profile_periodically_period steps.
657
657
# This is useful to debug scenarios where performance is changing.
658
658
659
+
# Managed ML diagnostics settings. If the feature is enabled, it will
660
+
# - create a managed ML diagnostics run with all the MaxText configs
661
+
# - upload xplane profiling, if it is enabled.
662
+
# - upload training metrics, at the defined log_period interval.
663
+
managed_mldiagnostics: False # Whether to enable the managed diagnostics
664
+
managed_mldiagnostics_run_group: ""# Optional. Used to group multiple runs.
0 commit comments