Skip to content

Commit 1e14343

Browse files
author
Jonathan Makunga
committed
Fix metric column name
1 parent 0bd6aa8 commit 1e14343

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

src/sagemaker/jumpstart/utils.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1220,6 +1220,8 @@ def get_metrics_from_deployment_configs(
12201220
if not deployment_configs:
12211221
return {}
12221222

1223+
print("deployment_configs: {}".format(deployment_configs))
1224+
12231225
data = {"Instance Type": [], "Config Name": [], "Concurrent Users": []}
12241226
instance_rate_data = {}
12251227
for index, deployment_config in enumerate(deployment_configs):
@@ -1256,26 +1258,27 @@ def get_metrics_from_deployment_configs(
12561258
instance_rate_data[instance_rate_column_name].append(instance_type_rate.value)
12571259

12581260
for metric in metrics:
1259-
column_name = _normalize_benchmark_metric_column_name(metric.name)
1261+
column_name = _normalize_benchmark_metric_column_name(metric.name, metric.unit)
12601262
data[column_name] = data.get(column_name, [])
12611263
data[column_name].append(metric.value)
12621264

12631265
data = {**data, **instance_rate_data}
12641266
return data
12651267

12661268

1267-
def _normalize_benchmark_metric_column_name(name: str) -> str:
1269+
def _normalize_benchmark_metric_column_name(name: str, unit: str) -> str:
12681270
"""Normalizes benchmark metric column name.
12691271
12701272
Args:
12711273
name (str): Name of the metric.
1274+
unit (str): Unit of the metric.
12721275
Returns:
12731276
str: Normalized metric column name.
12741277
"""
12751278
if "latency" in name.lower():
1276-
name = "Latency for each user (TTFT in ms)"
1279+
name = f"Latency, TTFT (P50 in {unit.lower()})"
12771280
elif "throughput" in name.lower():
1278-
name = "Throughput per user (token/seconds)"
1281+
name = f"Throughput (P50 in {unit.lower()}/user)"
12791282
return name
12801283

12811284

tests/unit/sagemaker/jumpstart/test_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1979,14 +1979,14 @@ def test__normalize_benchmark_metrics():
19791979

19801980

19811981
@pytest.mark.parametrize(
1982-
"name, expected",
1982+
"name, unit, expected",
19831983
[
1984-
("latency", "Latency for each user (TTFT in ms)"),
1985-
("throughput", "Throughput per user (token/seconds)"),
1984+
("latency", "sec", "Latency, TTFT (P50 in sec)"),
1985+
("throughput", "tokens/sec", "Throughput (P50 in tokens/sec/user)"),
19861986
],
19871987
)
1988-
def test__normalize_benchmark_metric_column_name(name, expected):
1989-
out = utils._normalize_benchmark_metric_column_name(name)
1988+
def test_normalize_benchmark_metric_column_name(name, unit, expected):
1989+
out = utils._normalize_benchmark_metric_column_name(name, unit)
19901990

19911991
assert out == expected
19921992

0 commit comments

Comments
 (0)