Skip to content

Commit b2fcd8e

Browse files
committed
Using function in compilation stage
Signed-off-by: Christian Sarmiento <cfsarmiento03@gmail.com>
1 parent 82859b8 commit b2fcd8e

File tree

3 files changed

+58
-53
lines changed

3 files changed

+58
-53
lines changed

aiu_fms_testing_utils/scripts/drive_paged_programs.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,9 @@
4444
get_programs_prompts,
4545
)
4646
from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string
47-
from aiu_fms_testing_utils.utils.resource_collection import instantiate_prometheus
48-
47+
from aiu_fms_testing_utils.utils.resource_collection import (
48+
instantiate_prometheus, print_step
49+
)
4950
# Constants
5051
PAD_MULTIPLE = 64
5152

aiu_fms_testing_utils/utils/__init__.py

Lines changed: 3 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@
1414
from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, world_size
1515
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
1616
from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string
17-
from aiu_fms_testing_utils.utils.resource_collection import (
18-
get_static_read, get_peak_read
19-
)
17+
from aiu_fms_testing_utils.utils.resource_collection import print_step
2018
from fms.utils.generation import pad_input_ids
2119
import torch
2220
import torch.nn as nn
@@ -50,38 +48,6 @@ def stagger_region(limit: int):
5048
torch.distributed.barrier()
5149
dprint("Stagger: All Complete")
5250

53-
def timestamp_print(given_string):
54-
"""
55-
Helper method that will add a timestamp before the given string that needs to be
56-
printed.
57-
58-
Args:
59-
- given_string: the string that is to be printed with the timestamp.
60-
"""
61-
62-
timestamp = datetime.now().strftime("%Y-%m-%d:%H:%M:%S")
63-
print(f"[{timestamp}] {given_string}")
64-
65-
def print_comp_resource_metrics(cpu_val, mem_val, stage, step):
66-
"""
67-
Helper method that will do a timestamp print for a specific step to report resource
68-
usage.
69-
70-
Args:
71-
- cpu_val: the value for CPU usage as a percentage that we want to print.
72-
- mem_val: the value for memory usage in gigabytes we want to print.
73-
- stage: The stage of the step we are in, either "peak" or "started".
74-
- step: The step that we performing in the script, either "compilation" or "inference".
75-
"""
76-
77-
if stage != "peak":
78-
if cpu_val is None or mem_val is None:
79-
timestamp_print(f"{step} {stage}")
80-
else:
81-
timestamp_print(f"{step} {stage} - CPU: {cpu_val:.2f}%, Memory: {mem_val:.2f} GB")
82-
83-
elif cpu_val is not None and mem_val is not None:
84-
dprint(f"Peak Resource Utilization - CPU: {cpu_val:.2f}%, Memory: {mem_val:.2f} GB")
8551

8652
def warmup_model(
8753
model: nn.Module,
@@ -114,9 +80,7 @@ def warmup_model(
11480
pt_compile_model_time = time.time()
11581

11682
## Report on initial resource usage
117-
metric_start = datetime.now(timezone.utc)
118-
initial_cpu, initial_mem = get_static_read(profile, metric_start)
119-
print_comp_resource_metrics(initial_cpu, initial_mem, "started", "Compilation")
83+
metric_start = print_step(profile, "started", "Compilation")
12084

12185
# adjust inputs depending on attn_type and dynamic shapes
12286
_warmup_input_ids = input_ids
@@ -148,14 +112,7 @@ def warmup_model(
148112
pt_compile_model_time = time.time() - pt_compile_model_time
149113

150114
# Get completed metric read
151-
metric_end = datetime.now(timezone.utc)
152-
end_cpu, end_mem = get_static_read(profile, metric_end)
153-
print_comp_resource_metrics(end_cpu, end_mem, "completed", "Compilation")
154-
155-
# Get the peak usage during compilation
156-
peak_cpu, peak_mem = get_peak_read(profile, metric_start, metric_end)
157-
print_comp_resource_metrics(peak_cpu, peak_mem, "peak", "Compilation")
158-
115+
print_step(profile, "completed", "Compilation", metric_start)
159116
dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")
160117

161118

aiu_fms_testing_utils/utils/resource_collection.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
from datetime import datetime, timezone
44

5-
from aiu_fms_testing_utils.utils import print_comp_resource_metrics
5+
from aiu_fms_testing_utils.utils.aiu_setup import dprint
66
try:
77
from prometheus_api_client import PrometheusConnect
88
except Exception as e:
@@ -142,18 +142,65 @@ def get_peak_read(client, start, end):
142142
return peak_cpu_value, peak_mem_value
143143

144144

145+
def timestamp_print(given_string):
146+
"""
147+
Helper method that will add a timestamp before the given string that needs to be
148+
printed.
149+
150+
Args:
151+
- given_string: the string that is to be printed with the timestamp.
152+
"""
153+
154+
timestamp = datetime.now().strftime("%Y-%m-%d:%H:%M:%S")
155+
print(f"[{timestamp}] {given_string}")
156+
157+
158+
def print_comp_resource_metrics(cpu_val, mem_val, stage, step):
159+
"""
160+
Helper method that will do a timestamp print for a specific step to report resource
161+
usage.
162+
163+
Args:
164+
- cpu_val: the value for CPU usage as a percentage that we want to print.
165+
- mem_val: the value for memory usage in gigabytes we want to print.
166+
- stage: The stage of the step we are in, either "peak" or "started".
167+
- step: The step that we performing in the script, either "compilation" or "inference".
168+
"""
169+
170+
if stage != "peak":
171+
if cpu_val is None or mem_val is None:
172+
timestamp_print(f"{step} {stage}")
173+
else:
174+
timestamp_print(f"{step} {stage} - CPU: {cpu_val:.2f}%, Memory: {mem_val:.2f} GB")
175+
176+
elif cpu_val is not None and mem_val is not None:
177+
dprint(f"Peak Resource Utilization - CPU: {cpu_val:.2f}%, Memory: {mem_val:.2f} GB")
178+
179+
145180
def print_step(p, step, stage, start_time=None):
146181
"""
182+
Print function to print out when a specific stage starts and ends,
183+
as well as reporting resource usage if enabled.
184+
185+
Args:
186+
- p: the Prometheus profile client to resource utilization collection.
187+
- step: string denoting what step we are at ("inference" or "compilation").
188+
- stage: string denoting what stage of the step we are at ("started" or "completed").
189+
- start_time: datetime object that denotes when the step started (optional).
190+
191+
Returns:
192+
- recorded_time: the time that was recorded when getting a metric read. Returned for
193+
scenarios where we need to use the recorded time in a later step (i.e completed stages).
147194
"""
148195

149196
## Get metric read
150-
timestep = datetime.now(timezone.utc)
151-
cpu_usage, mem_usage = get_static_read(p, timestep)
197+
recorded_time = datetime.now(timezone.utc)
198+
cpu_usage, mem_usage = get_static_read(p, recorded_time)
152199
print_comp_resource_metrics(cpu_usage, mem_usage, step, stage)
153200

154201
## Get and print the peak usage
155202
if start_time is not None:
156-
peak_cpu_inference_cpu, peak_mem_inference_cpu = get_peak_read(p, start_time, timestep)
203+
peak_cpu_inference_cpu, peak_mem_inference_cpu = get_peak_read(p, start_time, recorded_time)
157204
print_comp_resource_metrics(peak_cpu_inference_cpu, peak_mem_inference_cpu, "peak", stage)
158205

159-
return timestep
206+
return recorded_time

0 commit comments

Comments
 (0)