|
14 | 14 | from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, world_size |
15 | 15 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase |
16 | 16 | from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string |
17 | | -from aiu_fms_testing_utils.utils.resource_collection import ( |
18 | | - get_static_read, get_peak_read |
19 | | -) |
| 17 | +from aiu_fms_testing_utils.utils.resource_collection import print_step |
20 | 18 | from fms.utils.generation import pad_input_ids |
21 | 19 | import torch |
22 | 20 | import torch.nn as nn |
@@ -50,38 +48,6 @@ def stagger_region(limit: int): |
50 | 48 | torch.distributed.barrier() |
51 | 49 | dprint("Stagger: All Complete") |
52 | 50 |
|
53 | | -def timestamp_print(given_string): |
54 | | - """ |
55 | | - Helper method that will add a timestamp before the given string that needs to be |
56 | | - printed. |
57 | | -
|
58 | | - Args: |
59 | | - - given_string: the string that is to be printed with the timestamp. |
60 | | - """ |
61 | | - |
62 | | - timestamp = datetime.now().strftime("%Y-%m-%d:%H:%M:%S") |
63 | | - print(f"[{timestamp}] {given_string}") |
64 | | - |
65 | | -def print_comp_resource_metrics(cpu_val, mem_val, stage, step): |
66 | | - """ |
67 | | - Helper method that will do a timestamp print for a specific step to report resource |
68 | | - usage. |
69 | | -
|
70 | | - Args: |
71 | | - - cpu_val: the value for CPU usage as a percentage that we want to print. |
72 | | - - mem_val: the value for memory usage in gigabytes we want to print. |
73 | | -    - stage: The stage of the step we are in, e.g. "started", "completed", or "peak". |
74 | | -    - step: The step that we are performing in the script, either "compilation" or "inference". |
75 | | - """ |
76 | | - |
77 | | - if stage != "peak": |
78 | | - if cpu_val is None or mem_val is None: |
79 | | - timestamp_print(f"{step} {stage}") |
80 | | - else: |
81 | | - timestamp_print(f"{step} {stage} - CPU: {cpu_val:.2f}%, Memory: {mem_val:.2f} GB") |
82 | | - |
83 | | - elif cpu_val is not None and mem_val is not None: |
84 | | - dprint(f"Peak Resource Utilization - CPU: {cpu_val:.2f}%, Memory: {mem_val:.2f} GB") |
85 | 51 |
|
86 | 52 | def warmup_model( |
87 | 53 | model: nn.Module, |
@@ -114,9 +80,7 @@ def warmup_model( |
114 | 80 | pt_compile_model_time = time.time() |
115 | 81 |
|
116 | 82 | ## Report on initial resource usage |
117 | | - metric_start = datetime.now(timezone.utc) |
118 | | - initial_cpu, initial_mem = get_static_read(profile, metric_start) |
119 | | - print_comp_resource_metrics(initial_cpu, initial_mem, "started", "Compilation") |
| 83 | + metric_start = print_step(profile, "started", "Compilation") |
120 | 84 |
|
121 | 85 | # adjust inputs depending on attn_type and dynamic shapes |
122 | 86 | _warmup_input_ids = input_ids |
@@ -148,14 +112,7 @@ def warmup_model( |
148 | 112 | pt_compile_model_time = time.time() - pt_compile_model_time |
149 | 113 |
|
150 | 114 | # Get completed metric read |
151 | | - metric_end = datetime.now(timezone.utc) |
152 | | - end_cpu, end_mem = get_static_read(profile, metric_end) |
153 | | - print_comp_resource_metrics(end_cpu, end_mem, "completed", "Compilation") |
154 | | - |
155 | | - # Get the peak usage during compilation |
156 | | - peak_cpu, peak_mem = get_peak_read(profile, metric_start, metric_end) |
157 | | - print_comp_resource_metrics(peak_cpu, peak_mem, "peak", "Compilation") |
158 | | - |
| 115 | + print_step(profile, "completed", "Compilation", metric_start) |
159 | 116 | dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s") |
160 | 117 |
|
161 | 118 |
|
|
0 commit comments