Skip to content

Commit ee281ce

Browse files
committed
Integrate SDK for managed profiler
- include new SDK google-cloud-mldiagnostics - add new config params - add new file managed_mldiagnostics.py - modify profiler.py to profile with mldiagnostics - modify metrics_logger.py to upload metrics
1 parent ed517cf commit ee281ce

File tree

11 files changed

+169
-11
lines changed

11 files changed

+169
-11
lines changed

dependencies/requirements/base_requirements/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ flax
88
gcsfs
99
google-api-python-client
1010
google-cloud-aiplatform
11+
google-cloud-mldiagnostics
1112
google-cloud-monitoring
1213
grain[parquet]
1314
huggingface_hub

dependencies/requirements/generated_requirements/cuda12-requirements.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,12 @@ google-cloud-audit-log>=0.4.0
6666
google-cloud-bigquery>=3.38.0
6767
google-cloud-core>=2.5.0
6868
google-cloud-logging>=3.12.1
69+
google-cloud-mldiagnostics>=0.5.5
6970
google-cloud-monitoring>=2.28.0
7071
google-cloud-resource-manager>=1.15.0
7172
google-cloud-storage>=3.6.0
7273
google-crc32c>=1.7.1
7374
google-genai>=1.52.0
74-
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
7575
google-pasta>=0.2.0
7676
google-resumable-media>=2.8.0
7777
googleapis-common-protos>=1.72.0
@@ -120,7 +120,6 @@ mdurl>=0.1.2
120120
ml-collections>=1.1.0
121121
ml-dtypes>=0.5.3
122122
ml-goodput-measurement>=0.0.15
123-
mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
124123
more-itertools>=10.8.0
125124
mpmath>=1.3.0
126125
msgpack>=1.1.2
@@ -255,4 +254,4 @@ xprof>=2.21.1
255254
xxhash>=3.6.0
256255
yarl>=1.22.0
257256
zipp>=3.23.0
258-
zstandard>=0.25.0
257+
zstandard>=0.25.0

dependencies/requirements/generated_requirements/tpu-requirements.txt

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ cloudpickle>=3.1.1
3232
clu>=0.0.12
3333
colorama>=0.4.6
3434
contourpy>=1.3.3
35-
coverage>=7.11.3
35+
coverage>=7.12.0
3636
cycler>=0.12.1
3737
datasets>=4.4.1
3838
decorator>=5.2.1
@@ -66,12 +66,12 @@ google-cloud-audit-log>=0.4.0
6666
google-cloud-bigquery>=3.38.0
6767
google-cloud-core>=2.5.0
6868
google-cloud-logging>=3.12.1
69+
google-cloud-mldiagnostics>=0.5.5
6970
google-cloud-monitoring>=2.28.0
7071
google-cloud-resource-manager>=1.15.0
7172
google-cloud-storage>=3.6.0
7273
google-crc32c>=1.7.1
73-
google-genai>=1.50.1
74-
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
74+
google-genai>=1.51.0
7575
google-pasta>=0.2.0
7676
google-resumable-media>=2.8.0
7777
google-tunix>=0.1.3
@@ -123,7 +123,6 @@ mdurl>=0.1.2
123123
ml-collections>=1.1.0
124124
ml-dtypes>=0.5.3
125125
ml-goodput-measurement>=0.0.15
126-
mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
127126
more-itertools>=10.8.0
128127
mpmath>=1.3.0
129128
msgpack>=1.1.2
@@ -220,7 +219,7 @@ tensorflow>=2.19.1
220219
tensorstore>=0.1.78
221220
termcolor>=3.1.0
222221
tiktoken>=0.12.0
223-
tokamax>=0.0.5
222+
tokamax>=0.0.6
224223
tokenizers>=0.22.1
225224
toml>=0.10.2
226225
tomlkit>=0.13.3
@@ -245,4 +244,4 @@ xprof>=2.21.1
245244
xxhash>=3.6.0
246245
yarl>=1.22.0
247246
zipp>=3.23.0
248-
zstandard>=0.25.0
247+
zstandard>=0.25.0

dependencies/requirements/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ flax
88
gcsfs
99
google-api-python-client
1010
google-cloud-aiplatform
11+
google-cloud-mldiagnostics
1112
google-cloud-monitoring
1213
grain[parquet]
1314
huggingface_hub

dependencies/requirements/requirements_with_jax_ai_image.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
datasets @ https://github.com/huggingface/datasets/archive/6790e138c00b87a1ddc72184f89e7814cf784360.zip
44
flax>=0.11.0
55
google-api-python-client
6+
google-cloud-mldiagnostics
67
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
78
grain[parquet]>=0.2.13
89
jaxtyping

src/MaxText/configs/base.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,7 @@ colocated_python_data_input: False # experimental feature, under testing
610610

611611
# Training loop
612612
steps: 150_001 # If set to -1 then will inherit value from learning_rate_schedule_steps
613-
log_period: 100 # Flushes Tensorboard
613+
log_period: 100 # The frequency of Tensorboard flush, gcs metrics writing, and managed profiler metrics updating.
614614

615615
jax_distributed_initialization_timeout: 300 # This is the default timeout in https://github.com/jax-ml/jax/blob/main/jax/_src/distributed.py
616616
# Note there are two separate initializations - the jax coordination service (aka jax.distributed.initialize) and the backend (e.g. PjRT), the timeout above refers
@@ -656,6 +656,12 @@ profile_cleanly: True # If set to true, adds a block_until_ready on train state
656656
profile_periodically_period: -1 # If set to a positive integer, profile every profile_periodically_period steps.
657657
# This is useful to debug scenarios where performance is changing.
658658

659+
# Managed ML diagnostics settings. If the feature is enabled, it will
660+
# - create a managed ML diagnostics run with all the MaxText configs
661+
# - upload xplane profiling, if it is enabled.
662+
# - upload training metrics, at the defined log_period interval.
663+
managed_mldiagnostics: False # Whether to enable the managed diagnostics
664+
managed_mldiagnostics_run_group: "" # Optional. Used to group multiple runs.
659665

660666
# Dump HLO options
661667
dump_hlo: False

src/MaxText/configs/types.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1165,6 +1165,13 @@ class Metrics(BaseModel):
11651165
)
11661166

11671167

1168+
class ManagedMLDiagnostics(BaseModel):
1169+
"""Configuration for managed mldiagnostics."""
1170+
1171+
managed_mldiagnostics: bool = Field(False, description="Enable managed mldiagnostics.")
1172+
managed_mldiagnostics_run_group: str = Field("", description="Name used to group multiple runs.")
1173+
1174+
11681175
class Goodput(BaseModel):
11691176
"""Configuration for goodput monitoring."""
11701177

@@ -1419,6 +1426,10 @@ class DerivedValues(BaseModel):
14191426
None,
14201427
description="The full path to the tensorboard directory, derived from `run_name`.",
14211428
)
1429+
managed_mldiagnostics_dir: None | str = Field(
1430+
None,
1431+
description="The full path to the managed mldiagnostics directory, derived from `run_name`.",
1432+
)
14221433

14231434
rampup_end_step: None | int = Field(None, description="The step at which the batch size ramp-up phase concludes.")
14241435
tensors_on_device: None | list[str] = Field(
@@ -1534,6 +1545,7 @@ class MaxTextConfig(
15341545
Goodput,
15351546
GcpMonitoring,
15361547
Tensorboard,
1548+
ManagedMLDiagnostics,
15371549
# Multimodal
15381550
MultimodalGeneral,
15391551
VisionTower,
@@ -1579,6 +1591,8 @@ def set_derived_and_validate_values(self) -> "MaxTextConfig":
15791591
self.checkpoint_dir = os.path.join(output_dir, "checkpoints", "")
15801592
self.metrics_dir = os.path.join(output_dir, "metrics", "")
15811593
self.tensorboard_dir = os.path.join(output_dir, "tensorboard", "")
1594+
# To work around SDK bug b/454725283, remove the trailing back slash from the managed_mldiagnostics_dir.
1595+
self.managed_mldiagnostics_dir = os.path.join(output_dir, "managed-mldiagnostics")
15821596
else:
15831597
self.checkpoint_dir, self.metrics_dir, self.tensorboard_dir = None, None, None
15841598

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Create the managed mldiagnostics run."""
16+
from typing import Any
17+
18+
import simplejson as json
19+
20+
from google_cloud_mldiagnostics import machinelearning_run
21+
22+
from MaxText.pyconfig import KEYS_NO_LOGGING
23+
24+
25+
class ManagedMLDiagnostics:
26+
"""
27+
ML Diagnostics Run, implemented with the Singleton pattern.
28+
Ensures that only one instance of the class can exist.
29+
"""
30+
31+
_instance = None # Class attribute to hold the single instance
32+
33+
def __new__(cls, *args: Any, **kwargs: Any):
34+
"""
35+
Overrides the instance creation method.
36+
If an instance already exists, it is returned instead of creating a new one.
37+
"""
38+
if cls._instance is None:
39+
cls._instance = super(ManagedMLDiagnostics, cls).__new__(cls)
40+
41+
return cls._instance
42+
43+
def __init__(self, config):
44+
"""
45+
Initializes the ManagedMLDiagnostics, ensuring this method runs only once.
46+
"""
47+
# We need a flag to ensure __init__ only runs once,
48+
# as the object is returned multiple times by __new__.
49+
if hasattr(self, "_initialized"):
50+
return
51+
self._initialized = True
52+
if not config.managed_mldiagnostics:
53+
return
54+
55+
# Set up the managed mldiagnostics for profiling and metrics uploading.
56+
def should_log_key(key, value):
57+
if key in KEYS_NO_LOGGING:
58+
return False
59+
try:
60+
# Verify the value can be serialized to json. If not, we'll skip it.
61+
json.dumps(value, allow_nan=False)
62+
except TypeError:
63+
return False
64+
return True
65+
66+
config_dict = {key: value for key, value in config.get_keys().items() if should_log_key(key, value)}
67+
68+
# Create a run for the managed mldiagnostics, and upload the configuration.
69+
machinelearning_run(
70+
name=f"{config.run_name}",
71+
run_group=config.managed_mldiagnostics_run_group,
72+
configs=config_dict,
73+
gcs_path=config.managed_mldiagnostics_dir,
74+
# TODO: b/455623960 - Remove the following once multi-region and prod support are enabled.
75+
region="us-central1",
76+
environment="autopush", # Default would be "prod" for formal launch.
77+
)

src/MaxText/metric_logger.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,31 @@
2525

2626
import jax
2727

28+
from google_cloud_mldiagnostics import metrics as mlmetrics
29+
2830
from MaxText import max_logging
2931
from MaxText import max_utils
3032
from MaxText import maxtext_utils
33+
from MaxText.managed_mldiagnostics import ManagedMLDiagnostics
3134
from MaxText.utils import gcs_utils
3235
from MaxText.gcp_workload_monitor import GCPWorkloadMonitor
3336
from MaxText.globals import EPS
3437

3538
from collections import defaultdict
3639

40+
# Mapping MaxText metrics to managed profiler metrics
41+
_METRICS_TO_MANAGED = {
42+
"learning/current_learning_rate": "learning_rate",
43+
"learning/loss": "loss",
44+
"learning/grad_norm": "gradient_norm",
45+
"learning/total_weights": "total_weights",
46+
"perf/step_time_seconds": "step_time",
47+
"perf/per_device_tokens_per_sec": "throughput",
48+
"perf/per_device_tflops_per_sec": "tflops",
49+
# There are no mappings to the following metrics yet:
50+
# "latency", "mfu"
51+
}
52+
3753

3854
def _prepare_metrics_for_json(metrics, step, run_name):
3955
"""Converts metric dictionary into json supported types (e.g. float)"""
@@ -82,6 +98,8 @@ def __init__(self, config, learning_rate_schedule):
8298
self.learning_rate_schedule = learning_rate_schedule
8399
self.cumulative_eval_metrics = {"scalar": defaultdict(float)}
84100
self.buffered_train_metrics = None
101+
if self.config.managed_mldiagnostics:
102+
ManagedMLDiagnostics(config) # Initialize the MLRun instance.
85103

86104
def reset_eval_metrics(self):
87105
"""Resets the cumulative metrics dictionary for a new evaluation run."""
@@ -101,6 +119,9 @@ def write_metrics(self, metrics, step, is_training=True):
101119
if self.config.gcs_metrics and jax.process_index() == 0:
102120
self.write_metrics_for_gcs(metrics, step, is_training)
103121

122+
if self.config.managed_mldiagnostics:
123+
self.write_metrics_to_managed_mldiagnostics(metrics, step)
124+
104125
def log_metrics(self, metrics, step, is_training):
105126
"""Logs metrics via max_logging."""
106127
if is_training:
@@ -233,6 +254,18 @@ def write_metrics_to_tensorboard(self, metrics, step, is_training):
233254
max_logging.log(f"To see full metrics 'tensorboard --logdir={self.config.tensorboard_dir}'")
234255
self.writer.flush()
235256

257+
def write_metrics_to_managed_mldiagnostics(self, metrics, step):
258+
"""Write metrics to managed profiler."""
259+
if (step + 1) % self.config.log_period == 0 or step == self.config.steps - 1:
260+
for metric_name in metrics.get("scalar", []):
261+
value = metrics["scalar"][metric_name]
262+
# For NumPy/JAX array objects (including single-element arrays), use .item()
263+
# to extract the native Python scalar.
264+
if hasattr(value, "item"):
265+
value = value.item()
266+
mapped_metric_name = _METRICS_TO_MANAGED.get(metric_name, metric_name)
267+
mlmetrics.record(mapped_metric_name, value, step=int(step))
268+
236269
def write_setup_info_to_tensorboard(self, params):
237270
"""Writes setup information like train config params, num model params, and XLA flags to TensorBoard."""
238271
num_model_parameters = max_utils.calculate_num_params_from_pytree(params)

src/MaxText/profiler.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@
2121

2222
import jax
2323

24+
from google_cloud_mldiagnostics import xprof
25+
2426
from MaxText import max_logging
27+
from MaxText.managed_mldiagnostics import ManagedMLDiagnostics
2528

2629

2730
class Profiler:
@@ -40,6 +43,10 @@ def __init__(self, config, offset_step=0):
4043
self.finished_initial_profile_step = self._set_last_profiler_step(config.profiler_steps, config.steps)
4144
if config.profiler != "" and self.start_initial_profile_step >= config.steps:
4245
raise ValueError("Profiling requested but initial profiling step set past training final step")
46+
self.prof = None # managed mldiagnostics xprof collector.
47+
self.managed_mldiagnostics = config.managed_mldiagnostics
48+
if config.managed_mldiagnostics:
49+
ManagedMLDiagnostics(config) # Initialize the MLRun instance.
4350

4451
def maybe_activate_profiler(self, step, state):
4552
"""Conditionally activates the profiler based on the current step.
@@ -56,6 +63,16 @@ def activate(self, blocking_object=None, optional_postfix=""):
5663
nsys profiler becomes no-op when libcudart.so is not available on the system."""
5764
if self.profile_cleanly and blocking_object is not None:
5865
jax.block_until_ready(blocking_object)
66+
67+
if self.managed_mldiagnostics and self.mode == "xplane":
68+
# Handle the special profiling logic for managed_mldiagnostics
69+
if self.prof is None:
70+
# Starts xprof collector.
71+
# Only profiling on the first device, if not upload_all_profiler_results. None is for all devices.
72+
self.prof = xprof(process_index_list=None if self.upload_all_profiler_results else [0])
73+
self.prof.start()
74+
return
75+
5976
if not (self.upload_all_profiler_results or jax.process_index() == 0):
6077
return
6178
if self.mode != "":
@@ -84,6 +101,13 @@ def deactivate(self, blocking_object=None):
84101
The result is uploaded to the output bucket."""
85102
if self.profile_cleanly and blocking_object is not None:
86103
jax.block_until_ready(blocking_object)
104+
105+
if self.managed_mldiagnostics and self.mode == "xplane":
106+
# Handle the special profileing logic for managed_mldiagnostics
107+
if self.prof is not None:
108+
self.prof.stop()
109+
return
110+
87111
if not (self.upload_all_profiler_results or jax.process_index() == 0):
88112
return
89113
if self.mode == "nsys":

0 commit comments

Comments
 (0)