709 changes: 330 additions & 379 deletions benchmarks/profiler/profile_sla.py

Large diffs are not rendered by default.

68 changes: 1 addition & 67 deletions benchmarks/profiler/utils/config.py
@@ -17,7 +17,7 @@
import logging
import math
import shlex
from typing import Literal, Optional, Protocol
from typing import Optional

from pydantic import BaseModel

@@ -378,69 +378,3 @@ def update_image(config: dict, image: str) -> dict:
logger.debug(f"Updated image for {service_name} to {image}")

return cfg.model_dump()


class ConfigModifierProtocol(Protocol):
@classmethod
def convert_config(
cls,
config: dict,
target: Literal["prefill", "decode"],
is_moe_model: bool = False,
) -> dict:
...

@classmethod
def set_config_tp_size(
cls,
config: dict,
tp_size: int,
component_type: SubComponentType = SubComponentType.DECODE,
) -> dict:
...

@classmethod
def set_config_tep_size(
cls,
config: dict,
tep_size: int,
num_gpus_per_node: int,
component_type: SubComponentType = SubComponentType.DECODE,
) -> dict:
...

@classmethod
def set_config_dep_size(
cls,
config: dict,
dep_size: int,
num_gpus_per_node: int,
component_type: SubComponentType = SubComponentType.DECODE,
) -> dict:
...

@classmethod
def get_model_name(cls, config: dict) -> str:
...

@classmethod
def get_port(cls, config: dict) -> int:
...

@classmethod
def get_kv_cache_size_from_dynamo_log(
cls, dynamo_log_fn: str, attention_dp_size: int = 1
) -> int:
...

@classmethod
def load_default_config(cls) -> dict:
...

@classmethod
def update_model(cls, config: dict, model_name: str) -> dict:
...

@classmethod
def update_image(cls, config: dict, image: str) -> dict:
...
4 changes: 3 additions & 1 deletion benchmarks/profiler/utils/config_modifiers/__init__.py
@@ -16,7 +16,9 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from benchmarks.profiler.utils.config import ConfigModifierProtocol
from benchmarks.profiler.utils.config_modifiers.protocol import (
ConfigModifierProtocol,
)

from benchmarks.profiler.utils.config_modifiers.sglang import SGLangConfigModifier
from benchmarks.profiler.utils.config_modifiers.trtllm import TrtllmConfigModifier
@@ -0,0 +1,201 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import copy
import logging
from dataclasses import dataclass
from enum import Enum

from benchmarks.profiler.utils.model_info import ModelInfo

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s", "%Y-%m-%d %H:%M:%S"
)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)


class ParallelizationStrategy(Enum):
"""Enum for parallelization strategy types."""

TP = "TP"
TEP = "TEP"
DEP = "DEP"


@dataclass(frozen=True)
class ParallelizationMapping:
"""
Represents parallelization mapping of configs
"""

tp: int | None = None
tep: int | None = None
dep: int | None = None

def label(self) -> str:
if self.tp is not None:
return f"{ParallelizationStrategy.TP.value}={self.tp}"
if self.tep is not None:
return f"{ParallelizationStrategy.TEP.value}={self.tep}"
if self.dep is not None:
return f"{ParallelizationStrategy.DEP.value}={self.dep}"
return "default"

def get_tp_size(self) -> int:
"""
Get the effective TP size for KV heads splitting.
Both TP and TEP split KV heads, DEP doesn't (returns 1).
"""
if self.tp is not None:
return self.tp
if self.tep is not None:
return self.tep
return 1 # DEP has TP split of 1

def get_expert_split(self) -> int:
"""
Get the effective expert split size.
Both TEP and DEP split experts, TP doesn't (returns 1).
"""
if self.tep is not None:
return self.tep
if self.dep is not None:
return self.dep
return 1 # TP has expert split of 1
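
To make the accessor semantics concrete, here is a small illustrative check (not part of the diff; it assumes `ParallelizationMapping` is importable from this new module):

```python
# Illustrative only -- not part of the PR.
tp_map = ParallelizationMapping(tp=8)
tep_map = ParallelizationMapping(tep=8)
dep_map = ParallelizationMapping(dep=8)

assert tp_map.label() == "TP=8"
assert (tp_map.get_tp_size(), tp_map.get_expert_split()) == (8, 1)    # TP splits KV heads only
assert (tep_map.get_tp_size(), tep_map.get_expert_split()) == (8, 8)  # TEP splits KV heads and experts
assert (dep_map.get_tp_size(), dep_map.get_expert_split()) == (1, 8)  # DEP splits experts only
```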


def _check_divisibility(
value: int | None,
divisor: int,
value_name: str,
divisor_name: str,
mapping_label: str,
) -> bool:
"""
Check if value is divisible by divisor.
Returns True if valid (or value is None), False if invalid.
Args:
value: The value to check (e.g., num_kv_heads, num_experts)
divisor: The divisor to check against
value_name: Name of the value for error messages
divisor_name: Name of the divisor for error messages (e.g., "tp_size", "expert_split")
mapping_label: Label of the mapping for error messages
"""
if value is None:
logger.warning(
f"Skipping {value_name} divisibility check for {mapping_label}: {value_name} is unknown"
)
return True

if divisor > 1 and int(value) % divisor != 0:
logger.warning(
f"Invalid mapping {mapping_label}: {value_name}={value} not divisible by {divisor_name}={divisor}"
)
return False

return True


def _validate_intermediate_size(
mapping: ParallelizationMapping,
intermediate_size: int | None,
quant_block: int | None,
) -> bool:
"""
Validate intermediate size and quantization block for TP and TEP strategies.
Checks:
- intermediate_size % tp_size == 0
- (intermediate_size // tp_size) is divisible by quant_block (if quant_block is known)
"""
tp_size = mapping.get_tp_size()

# Check basic divisibility
if not _check_divisibility(
intermediate_size, tp_size, "intermediate_size", "tp_size", mapping.label()
):
return False

# Additional check for quantization block constraint
if intermediate_size is not None and quant_block is not None and tp_size > 1:
per_shard = int(intermediate_size) // tp_size
if not _check_divisibility(
per_shard, quant_block, "per_shard", "quant_block", mapping.label()
):
return False

return True
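
A worked example with illustrative numbers (not taken from the PR): with a block-quantization width of 128, TP=8 passes because each shard keeps 14336 // 8 = 1792 columns and 1792 % 128 == 0, while TP=3 already fails the first divisibility check.

```python
# Illustrative values only (not from the PR).
ok = _validate_intermediate_size(
    ParallelizationMapping(tp=8), intermediate_size=14336, quant_block=128
)
# 14336 % 8 == 0 and (14336 // 8) % 128 == 0  -> ok is True

bad = _validate_intermediate_size(
    ParallelizationMapping(tp=3), intermediate_size=14336, quant_block=128
)
# 14336 % 3 != 0 -> a warning is logged and bad is False
```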


def get_candidate_parallel_mappings(
num_gpus: int, model_info: ModelInfo, phase: str
) -> list[ParallelizationMapping]:
"""
Return a list of candidate parallelization mappings for a given GPU count and phase,
verified against model properties.
Verification rules:
- TP and TEP must divide num_kv_heads (if available)
- TEP and DEP must divide num_experts (if available)
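- intermediate_size must be divisible by the TP/TEP split, and the per-shard size by the quantization block (if available)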
"""
is_moe = bool(model_info.is_moe)
num_kv_heads = model_info.num_kv_heads
num_experts = model_info.num_experts
intermediate_size = model_info.intermediate_size
quant_block = model_info.quantization_block_size

candidates: list[ParallelizationMapping] = []
if is_moe:
if phase == "prefill":
candidates = [ParallelizationMapping(tep=num_gpus)]
elif phase == "decode":
candidates = [ParallelizationMapping(dep=num_gpus)]
else:
candidates = [ParallelizationMapping(tp=num_gpus)]

# Verify candidates against model constraints
verified: list[ParallelizationMapping] = []
for m in candidates:
# Check KV heads divisibility
if not _check_divisibility(
num_kv_heads, m.get_tp_size(), "num_kv_heads", "tp_size", m.label()
):
continue

# Check experts divisibility
if not _check_divisibility(
num_experts, m.get_expert_split(), "num_experts", "expert_split", m.label()
):
continue

# Check intermediate size and quantization block
if not _validate_intermediate_size(m, intermediate_size, quant_block):
continue

verified.append(m)

return verified
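
A usage sketch under stated assumptions: `ModelInfo` is stood in for by a `SimpleNamespace` carrying only the attributes this function reads; the real dataclass in `benchmarks.profiler.utils.model_info` may require more fields.

```python
from types import SimpleNamespace

# Stand-in for ModelInfo with only the attributes read above (hypothetical values).
moe_model = SimpleNamespace(
    is_moe=True,
    num_kv_heads=8,
    num_experts=64,
    intermediate_size=2048,
    quantization_block_size=None,  # unknown -> quant-block check is skipped
)

# MoE decode on 16 GPUs -> a single DEP=16 candidate, kept since 64 % 16 == 0.
print([m.label() for m in get_candidate_parallel_mappings(16, moe_model, "decode")])
# ['DEP=16']

# MoE prefill on 16 GPUs -> TEP=16 is dropped: num_kv_heads=8 is not divisible by 16.
print([m.label() for m in get_candidate_parallel_mappings(16, moe_model, "prefill")])
# []
```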


def apply_parallel_mapping_to_config(
base_config: dict,
mapping: ParallelizationMapping,
phase: str,
config_modifier,
num_gpus_per_node: int | None,
) -> dict:
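    """Return a deep copy of base_config with the chosen parallelization mapping applied via the backend's config modifier."""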
cfg = copy.deepcopy(base_config)
if mapping.tp is not None:
cfg = config_modifier.set_config_tp_size(cfg, mapping.tp)
elif phase == "prefill" and mapping.tep is not None:
cfg = config_modifier.set_config_tep_size(cfg, mapping.tep, num_gpus_per_node)
elif phase == "decode" and mapping.dep is not None:
cfg = config_modifier.set_config_dep_size(cfg, mapping.dep, num_gpus_per_node)
else:
pass
return cfg
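
A minimal sketch of the dispatch, using a hypothetical stand-in for the backend config modifier (the real ones, e.g. `SGLangConfigModifier`, implement `ConfigModifierProtocol`):

```python
class RecordingModifier:
    """Hypothetical stand-in; real backends implement ConfigModifierProtocol."""

    @classmethod
    def set_config_tp_size(cls, cfg, tp_size, **_):
        return {**cfg, "tp": tp_size}

    @classmethod
    def set_config_tep_size(cls, cfg, tep_size, num_gpus_per_node, **_):
        return {**cfg, "tep": tep_size, "gpus_per_node": num_gpus_per_node}

    @classmethod
    def set_config_dep_size(cls, cfg, dep_size, num_gpus_per_node, **_):
        return {**cfg, "dep": dep_size, "gpus_per_node": num_gpus_per_node}


base = {"model": "example"}
cfg = apply_parallel_mapping_to_config(
    base, ParallelizationMapping(dep=16), "decode", RecordingModifier, num_gpus_per_node=8
)
# cfg == {'model': 'example', 'dep': 16, 'gpus_per_node': 8}; base is untouched (deep copy).
```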
84 changes: 84 additions & 0 deletions benchmarks/profiler/utils/config_modifiers/protocol.py
@@ -0,0 +1,84 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Literal, Protocol

from dynamo.planner.defaults import SubComponentType


class ConfigModifierProtocol(Protocol):
@classmethod
def convert_config(
cls,
config: dict,
target: Literal["prefill", "decode"],
is_moe_model: bool = False,
) -> dict:
...

@classmethod
def set_config_tp_size(
cls,
config: dict,
tp_size: int,
component_type: SubComponentType = SubComponentType.DECODE,
) -> dict:
...

@classmethod
def set_config_tep_size(
cls,
config: dict,
tep_size: int,
num_gpus_per_node: int,
component_type: SubComponentType = SubComponentType.DECODE,
) -> dict:
...

@classmethod
def set_config_dep_size(
cls,
config: dict,
dep_size: int,
num_gpus_per_node: int,
component_type: SubComponentType = SubComponentType.DECODE,
) -> dict:
...

@classmethod
def get_model_name(cls, config: dict) -> str:
...

@classmethod
def get_port(cls, config: dict) -> int:
...

@classmethod
def get_kv_cache_size_from_dynamo_log(
cls, dynamo_log_fn: str, attention_dp_size: int = 1
) -> int:
...

@classmethod
def load_default_config(cls) -> dict:
...

@classmethod
def update_model(cls, config: dict, model_name: str) -> dict:
...

@classmethod
def update_image(cls, config: dict, image: str) -> dict:
...
6 changes: 6 additions & 0 deletions benchmarks/profiler/utils/config_modifiers/sglang.py
@@ -292,6 +292,12 @@ def get_model_name(cls, config: dict) -> str:
return DEFAULT_MODEL_NAME

args = break_arguments(args)
# Check for --model-path first (primary argument for SGLang)
for i, arg in enumerate(args):
if arg == "--model-path" and i + 1 < len(args):
return args[i + 1]

# Fall back to --served-model-name if --model-path not found
for i, arg in enumerate(args):
if arg == "--served-model-name" and i + 1 < len(args):
return args[i + 1]
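
The effect of the new precedence on a toy argument list (a standalone sketch of the lookup order; `resolve_model_name` is illustrative, not the real `SGLangConfigModifier` method):

```python
def resolve_model_name(args: list[str]) -> str | None:
    # Same order as above: --model-path wins over --served-model-name.
    for flag in ("--model-path", "--served-model-name"):
        for i, arg in enumerate(args):
            if arg == flag and i + 1 < len(args):
                return args[i + 1]
    return None

args = ["--model-path", "Qwen/Qwen3-0.6B", "--served-model-name", "my-alias"]
print(resolve_model_name(args))  # Qwen/Qwen3-0.6B
```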