Commit cb47516

Merge pull request #2544 from AI-Hypercomputer:xibin/diagon_sdk
PiperOrigin-RevId: 839818265
2 parents 73e449d + ee281ce commit cb47516

11 files changed (+165, -8 lines)

dependencies/requirements/base_requirements/requirements.txt
Lines changed: 1 addition & 0 deletions

@@ -8,6 +8,7 @@ flax
 gcsfs
 google-api-python-client
 google-cloud-aiplatform
+google-cloud-mldiagnostics
 google-cloud-monitoring
 grain[parquet]
 huggingface_hub

dependencies/requirements/generated_requirements/cuda12-requirements.txt
Lines changed: 2 additions & 3 deletions

@@ -66,12 +66,12 @@ google-cloud-audit-log>=0.4.0
 google-cloud-bigquery>=3.38.0
 google-cloud-core>=2.5.0
 google-cloud-logging>=3.12.1
+google-cloud-mldiagnostics>=0.5.5
 google-cloud-monitoring>=2.28.0
 google-cloud-resource-manager>=1.15.0
 google-cloud-storage>=3.6.0
 google-crc32c>=1.7.1
 google-genai>=1.52.0
-google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
 google-pasta>=0.2.0
 google-resumable-media>=2.8.0
 googleapis-common-protos>=1.72.0
@@ -120,7 +120,6 @@ mdurl>=0.1.2
 ml-collections>=1.1.0
 ml-dtypes>=0.5.4
 ml-goodput-measurement>=0.0.15
-mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
 more-itertools>=10.8.0
 mpmath>=1.3.0
 msgpack>=1.1.2
@@ -255,4 +254,4 @@ xprof>=2.21.1
 xxhash>=3.6.0
 yarl>=1.22.0
 zipp>=3.23.0
-zstandard>=0.25.0
+zstandard>=0.25.0

dependencies/requirements/generated_requirements/tpu-requirements.txt
Lines changed: 2 additions & 3 deletions

@@ -66,12 +66,12 @@ google-cloud-audit-log>=0.4.0
 google-cloud-bigquery>=3.38.0
 google-cloud-core>=2.5.0
 google-cloud-logging>=3.12.1
+google-cloud-mldiagnostics>=0.5.5
 google-cloud-monitoring>=2.28.0
 google-cloud-resource-manager>=1.15.0
 google-cloud-storage>=3.6.0
 google-crc32c>=1.7.1
 google-genai>=1.52.0
-google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
 google-pasta>=0.2.0
 google-resumable-media>=2.8.0
 google-tunix>=0.1.3
@@ -123,7 +123,6 @@ mdurl>=0.1.2
 ml-collections>=1.1.0
 ml-dtypes>=0.5.4
 ml-goodput-measurement>=0.0.15
-mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
 more-itertools>=10.8.0
 mpmath>=1.3.0
 msgpack>=1.1.2
@@ -245,4 +244,4 @@ xprof>=2.21.1
 xxhash>=3.6.0
 yarl>=1.22.0
 zipp>=3.23.0
-zstandard>=0.25.0
+zstandard>=0.25.0

dependencies/requirements/requirements.txt
Lines changed: 1 addition & 0 deletions

@@ -8,6 +8,7 @@ flax
 gcsfs
 google-api-python-client
 google-cloud-aiplatform
+google-cloud-mldiagnostics
 google-cloud-monitoring
 grain[parquet]
 huggingface_hub

dependencies/requirements/requirements_with_jax_ai_image.txt
Lines changed: 1 addition & 0 deletions

@@ -3,6 +3,7 @@
 datasets @ https://github.com/huggingface/datasets/archive/6790e138c00b87a1ddc72184f89e7814cf784360.zip
 flax>=0.11.0
 google-api-python-client
+google-cloud-mldiagnostics
 google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
 grain[parquet]>=0.2.15
 jaxtyping

src/MaxText/configs/base.yml
Lines changed: 7 additions & 1 deletion

@@ -612,7 +612,7 @@ colocated_python_data_input: False # experimental feature, under testing

 # Training loop
 steps: 150_001 # If set to -1 then will inherit value from learning_rate_schedule_steps
-log_period: 100 # Flushes Tensorboard
+log_period: 100 # The frequency of Tensorboard flushes, GCS metrics writes, and managed profiler metric updates.

 jax_distributed_initialization_timeout: 300 # This is the default timeout in https://github.com/jax-ml/jax/blob/main/jax/_src/distributed.py
 # Note there are two separate initializations - the jax coordination service (aka jax.distributed.initialize) and the backend (e.g. PjRT), the timeout above refers
@@ -658,6 +658,12 @@ profile_cleanly: True # If set to true, adds a block_until_ready on train state
 profile_periodically_period: -1 # If set to a positive integer, profile every profile_periodically_period steps.
 # This is useful to debug scenarios where performance is changing.

+# Managed ML diagnostics settings. If the feature is enabled, it will
+# - create a managed ML diagnostics run with all the MaxText configs
+# - upload xplane profiles, if profiling is enabled.
+# - upload training metrics at the defined log_period interval.
+managed_mldiagnostics: False # Whether to enable managed ML diagnostics.
+managed_mldiagnostics_run_group: "" # Optional. Used to group multiple runs.

 # Dump HLO options
 dump_hlo: False
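
Illustration only (not part of the commit): with the base.yml defaults above, the `(step + 1) % log_period` check used by the metric logger later in this commit picks the following upload steps.

# Illustration only, assuming the base.yml defaults log_period=100 and steps=150_001.
log_period, steps = 100, 150_001
upload_steps = [s for s in range(steps) if (s + 1) % log_period == 0 or s == steps - 1]
print(upload_steps[:3], upload_steps[-1])  # [99, 199, 299] 150000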

src/MaxText/configs/types.py
Lines changed: 14 additions & 0 deletions

@@ -1169,6 +1169,13 @@ class Metrics(BaseModel):
   )


+class ManagedMLDiagnostics(BaseModel):
+  """Configuration for managed mldiagnostics."""
+
+  managed_mldiagnostics: bool = Field(False, description="Enable managed mldiagnostics.")
+  managed_mldiagnostics_run_group: str = Field("", description="Name used to group multiple runs.")
+
+
 class Goodput(BaseModel):
   """Configuration for goodput monitoring."""

@@ -1428,6 +1435,10 @@ class DerivedValues(BaseModel):
       None,
       description="The full path to the tensorboard directory, derived from `run_name`.",
   )
+  managed_mldiagnostics_dir: None | str = Field(
+      None,
+      description="The full path to the managed mldiagnostics directory, derived from `run_name`.",
+  )

   rampup_end_step: None | int = Field(None, description="The step at which the batch size ramp-up phase concludes.")
   tensors_on_device: None | list[str] = Field(
@@ -1543,6 +1554,7 @@ class MaxTextConfig(
     Goodput,
     GcpMonitoring,
     Tensorboard,
+    ManagedMLDiagnostics,
     # Multimodal
     MultimodalGeneral,
     VisionTower,
@@ -1588,6 +1600,8 def set_derived_and_validate_values(self) -> "MaxTextConfig":
       self.checkpoint_dir = os.path.join(output_dir, "checkpoints", "")
       self.metrics_dir = os.path.join(output_dir, "metrics", "")
       self.tensorboard_dir = os.path.join(output_dir, "tensorboard", "")
+      # To work around SDK bug b/454725283, omit the trailing slash from managed_mldiagnostics_dir.
+      self.managed_mldiagnostics_dir = os.path.join(output_dir, "managed-mldiagnostics")
     else:
       self.checkpoint_dir, self.metrics_dir, self.tensorboard_dir = None, None, None
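
Illustration only (not part of the commit): the derived managed_mldiagnostics_dir deliberately skips the trailing separator that the other derived directories keep, as a sketch with a hypothetical output_dir shows.

# Illustration only; "gs://my-bucket/my-run" is a hypothetical output_dir.
import os

output_dir = "gs://my-bucket/my-run"
tensorboard_dir = os.path.join(output_dir, "tensorboard", "")  # keeps a trailing slash
managed_mldiagnostics_dir = os.path.join(output_dir, "managed-mldiagnostics")  # no trailing slash (b/454725283)
print(tensorboard_dir)            # gs://my-bucket/my-run/tensorboard/
print(managed_mldiagnostics_dir)  # gs://my-bucket/my-run/managed-mldiagnostics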

src/MaxText/managed_mldiagnostics.py (new file)
Lines changed: 76 additions & 0 deletions

@@ -0,0 +1,76 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Create the managed mldiagnostics run."""
+import json
+from typing import Any
+
+import google_cloud_mldiagnostics as mldiag
+
+from MaxText.pyconfig import KEYS_NO_LOGGING
+
+
+class ManagedMLDiagnostics:
+  """
+  ML Diagnostics Run, implemented with the Singleton pattern.
+  Ensures that only one instance of the class can exist.
+  """
+
+  _instance = None  # Class attribute to hold the single instance.
+
+  def __new__(cls, *args: Any, **kwargs: Any):
+    """
+    Overrides the instance creation method.
+    If an instance already exists, it is returned instead of creating a new one.
+    """
+    if cls._instance is None:
+      cls._instance = super(ManagedMLDiagnostics, cls).__new__(cls)
+
+    return cls._instance
+
+  def __init__(self, config):
+    """
+    Initializes the ManagedMLDiagnostics, ensuring this method runs only once.
+    """
+    # We need a flag to ensure __init__ only runs once,
+    # as the object is returned multiple times by __new__.
+    if hasattr(self, "_initialized"):
+      return
+    self._initialized = True
+    if not config.managed_mldiagnostics:
+      return
+
+    # Set up the managed mldiagnostics for profiling and metrics uploading.
+    def should_log_key(key, value):
+      if key in KEYS_NO_LOGGING:
+        return False
+      try:
+        # Verify the value can be serialized to json. If not, we'll skip it.
+        json.dumps(value, allow_nan=False)
+      except TypeError:
+        return False
+      return True
+
+    config_dict = {key: value for key, value in config.get_keys().items() if should_log_key(key, value)}
+
+    # Create a run for the managed mldiagnostics, and upload the configuration.
+    mldiag.machinelearning_run(
+        name=f"{config.run_name}",
+        run_group=config.managed_mldiagnostics_run_group,
+        configs=config_dict,
+        gcs_path=config.managed_mldiagnostics_dir,
+        # TODO: b/455623960 - Remove the following once multi-region and prod support are enabled.
+        region="us-central1",
+        environment="autopush",  # Default would be "prod" for formal launch.
+    )
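
Illustration only (not part of the commit): a minimal sketch of the singleton contract the class relies on. The import path and the SimpleNamespace stub are assumptions, and managed_mldiagnostics is left False so no SDK call is made.

# Sketch of the singleton behavior; stub config keeps the feature disabled.
from types import SimpleNamespace

from MaxText.managed_mldiagnostics import ManagedMLDiagnostics  # assumed import path

stub_config = SimpleNamespace(managed_mldiagnostics=False)  # hypothetical config stub
first = ManagedMLDiagnostics(stub_config)
second = ManagedMLDiagnostics(stub_config)
assert first is second  # __new__ returns the cached instance; __init__ body runs only once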

src/MaxText/metric_logger.py
Lines changed: 33 additions & 0 deletions

@@ -25,15 +25,31 @@

 import jax

+import google_cloud_mldiagnostics as mldiag
+
 from MaxText import max_logging
 from MaxText import max_utils
 from MaxText import maxtext_utils
+from MaxText.managed_mldiagnostics import ManagedMLDiagnostics
 from MaxText.utils import gcs_utils
 from MaxText.gcp_workload_monitor import GCPWorkloadMonitor
 from MaxText.globals import EPS

 from collections import defaultdict

+# Mapping MaxText metrics to managed profiler metrics
+_METRICS_TO_MANAGED = {
+    "learning/current_learning_rate": "learning_rate",
+    "learning/loss": "loss",
+    "learning/grad_norm": "gradient_norm",
+    "learning/total_weights": "total_weights",
+    "perf/step_time_seconds": "step_time",
+    "perf/per_device_tokens_per_sec": "throughput",
+    "perf/per_device_tflops_per_sec": "tflops",
+    # There are no mappings to the following metrics yet:
+    # "latency", "mfu"
+}
+

 def _prepare_metrics_for_json(metrics, step, run_name):
   """Converts metric dictionary into json supported types (e.g. float)"""
@@ -82,6 +98,8 @@ def __init__(self, config, learning_rate_schedule):
     self.learning_rate_schedule = learning_rate_schedule
     self.cumulative_eval_metrics = {"scalar": defaultdict(float)}
     self.buffered_train_metrics = None
+    if self.config.managed_mldiagnostics:
+      ManagedMLDiagnostics(config)  # Initialize the MLRun instance.

   def reset_eval_metrics(self):
     """Resets the cumulative metrics dictionary for a new evaluation run."""
@@ -101,6 +119,9 @@ def write_metrics(self, metrics, step, is_training=True):
     if self.config.gcs_metrics and jax.process_index() == 0:
       self.write_metrics_for_gcs(metrics, step, is_training)

+    if self.config.managed_mldiagnostics:
+      self.write_metrics_to_managed_mldiagnostics(metrics, step)
+
   def log_metrics(self, metrics, step, is_training):
     """Logs metrics via max_logging."""
     if is_training:
@@ -233,6 +254,18 @@ def write_metrics_to_tensorboard(self, metrics, step, is_training):
       max_logging.log(f"To see full metrics 'tensorboard --logdir={self.config.tensorboard_dir}'")
       self.writer.flush()

+  def write_metrics_to_managed_mldiagnostics(self, metrics, step):
+    """Write metrics to managed profiler."""
+    if (step + 1) % self.config.log_period == 0 or step == self.config.steps - 1:
+      for metric_name in metrics.get("scalar", []):
+        value = metrics["scalar"][metric_name]
+        # For NumPy/JAX array objects (including single-element arrays), use .item()
+        # to extract the native Python scalar.
+        if hasattr(value, "item"):
+          value = value.item()
+        mapped_metric_name = _METRICS_TO_MANAGED.get(metric_name, metric_name)
+        mldiag.metrics.record(mapped_metric_name, value, step=int(step))
+
   def write_setup_info_to_tensorboard(self, params):
     """Writes setup information like train config params, num model params, and XLA flags to TensorBoard."""
     num_model_parameters = max_utils.calculate_num_params_from_pytree(params)
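
Illustration only (not part of the commit): the mapping falls back to the original metric name when no managed name exists, and array-like scalars are converted with .item() before recording.

# Illustration only: a pared-down excerpt of the mapping with hypothetical values.
import numpy as np

_METRICS_TO_MANAGED = {"learning/loss": "loss", "perf/step_time_seconds": "step_time"}  # excerpt
scalars = {"learning/loss": np.float32(2.5), "perf/mfu": 0.41}  # hypothetical metric values

for name, value in scalars.items():
  if hasattr(value, "item"):
    value = value.item()  # NumPy/JAX scalar -> native Python float
  print(_METRICS_TO_MANAGED.get(name, name), value)
# loss 2.5
# perf/mfu 0.41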

src/MaxText/profiler.py
Lines changed: 24 additions & 0 deletions

@@ -21,7 +21,10 @@

 import jax

+import google_cloud_mldiagnostics as mldiag
+
 from MaxText import max_logging
+from MaxText.managed_mldiagnostics import ManagedMLDiagnostics


 class Profiler:
@@ -40,6 +43,10 @@ def __init__(self, config, offset_step=0):
     self.finished_initial_profile_step = self._set_last_profiler_step(config.profiler_steps, config.steps)
     if config.profiler != "" and self.start_initial_profile_step >= config.steps:
       raise ValueError("Profiling requested but initial profiling step set past training final step")
+    self.prof = None  # Managed mldiagnostics xprof collector.
+    self.managed_mldiagnostics = config.managed_mldiagnostics
+    if config.managed_mldiagnostics:
+      ManagedMLDiagnostics(config)  # Initialize the MLRun instance.

   def maybe_activate_profiler(self, step, state):
     """Conditionally activates the profiler based on the current step.
@@ -56,6 +63,16 @@ def activate(self, blocking_object=None, optional_postfix=""):
     nsys profiler becomes no-op when libcudart.so is not available on the system."""
     if self.profile_cleanly and blocking_object is not None:
       jax.block_until_ready(blocking_object)
+
+    if self.managed_mldiagnostics and self.mode == "xplane":
+      # Handle the special profiling logic for managed_mldiagnostics.
+      if self.prof is None:
+        # Start the xprof collector.
+        # Profile only the first process unless upload_all_profiler_results is set; None means all processes.
+        self.prof = mldiag.xprof(process_index_list=None if self.upload_all_profiler_results else [0])
+        self.prof.start()
+      return
+
     if not (self.upload_all_profiler_results or jax.process_index() == 0):
       return
     if self.mode != "":
@@ -84,6 +101,13 @@ def deactivate(self, blocking_object=None):
     The result is uploaded to the output bucket."""
     if self.profile_cleanly and blocking_object is not None:
       jax.block_until_ready(blocking_object)
+
+    if self.managed_mldiagnostics and self.mode == "xplane":
+      # Handle the special profiling logic for managed_mldiagnostics.
+      if self.prof is not None:
+        self.prof.stop()
+      return
+
     if not (self.upload_all_profiler_results or jax.process_index() == 0):
       return
     if self.mode == "nsys":
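
Illustration only (not part of the commit): a self-contained sketch of how the new guards short-circuit the pre-existing jax.profiler/nsys paths when managed xplane profiling is on. _FakeXprof and _MiniProfiler are stand-ins, and the placement of the early return reflects the reconstructed indentation above.

# Stand-in sketch of the activate/deactivate guard logic; not the real SDK.
class _FakeXprof:
  def start(self):
    print("managed xprof collection started")

  def stop(self):
    print("managed xprof collection stopped")


class _MiniProfiler:
  """Pared-down copy of the new activate/deactivate guards."""

  def __init__(self, managed_mldiagnostics, mode):
    self.managed_mldiagnostics, self.mode, self.prof = managed_mldiagnostics, mode, None

  def activate(self):
    if self.managed_mldiagnostics and self.mode == "xplane":
      if self.prof is None:
        self.prof = _FakeXprof()
        self.prof.start()
      return
    print("standard jax.profiler / nsys path")

  def deactivate(self):
    if self.managed_mldiagnostics and self.mode == "xplane":
      if self.prof is not None:
        self.prof.stop()
      return
    print("standard stop path")


prof = _MiniProfiler(managed_mldiagnostics=True, mode="xplane")
prof.activate()    # starts the collector once
prof.activate()    # collector already exists; still skips the standard path
prof.deactivate()  # stops the collector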
