Skip to content

Commit dcf8615

Browse files
committed
fix: harden max_concurrency with validation, autoscale coverage, and tests
- to_yaml() now includes max_concurrency when set (was silently dropped) - Autoscale path (pick_autoscale) now filters by max_concurrency - Validate max_concurrency >= 1 in TaskConfig (raises ValueError) - Add INFO-level log when max_concurrency constraint is active - Update find_best_disagg_result_under_constraints docstring - Add tests: validation rejects 0/-5, to_yaml round-trip with/without Signed-off-by: Yimingl <yimingl@nvidia.com>
1 parent 1f24b4c commit dcf8615

File tree

5 files changed

+72
-0
lines changed

5 files changed

+72
-0
lines changed

src/aiconfigurator/sdk/inference_session.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,7 @@ def _pick_autoscale(
406406
target_ttft: float | None = None,
407407
target_tpot: float | None = None,
408408
top_n: int = 5,
409+
max_concurrency: int | None = None,
409410
) -> InferenceSummary:
410411
"""Pick best prefill and decode engines independently for autoscaling.
411412
@@ -427,6 +428,7 @@ def _pick_autoscale(
427428
target_ttft=target_ttft,
428429
target_tpot=target_tpot,
429430
top_n=top_n,
431+
max_concurrency=max_concurrency,
430432
)
431433

432434
disagg_summary_df = result["best_config_df"]
@@ -478,6 +480,13 @@ def find_best_disagg_result_under_constraints(
478480
decode_max_num_tokens (int): the decode max num tokens
479481
decode_num_worker_list (List[int]): the decode num worker list
480482
num_gpu_list (Optional[List[int]]): the num gpu list
483+
require_same_tp (bool): require same TP for prefill and decode
484+
autoscale (bool): use autoscale picking (P and D chosen independently)
485+
target_tpot (Optional[float]): TPOT target for autoscale mode
486+
max_concurrency (Optional[int]): maximum global concurrency.
487+
Compositions whose ``concurrency`` exceeds this value are
488+
excluded from the search in both rate-matching and autoscale
489+
paths.
481490
482491
Returns:
483492
Optional[InferenceSummary]: the summary of the inference result, contains all the
@@ -710,6 +719,7 @@ def _find_best_result_under_constraints(
710719
runtime_config=runtime_config,
711720
disagg_summary=disagg_summary,
712721
target_tpot=target_tpot,
722+
max_concurrency=max_concurrency,
713723
)
714724

715725
# find best result under constraints

src/aiconfigurator/sdk/pareto_analysis.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ def agg_pareto(
5151
results_df: dataframe of the results
5252
"""
5353

54+
if max_concurrency is not None:
55+
logger.info("agg_pareto: max_concurrency=%d is active; capping batch-size sweep per config", max_concurrency)
56+
5457
# agg is agg server, the loop over parallel is outside here.
5558
results_df = pd.DataFrame(columns=ColumnsAgg)
5659
exceptions = []
@@ -284,6 +287,8 @@ def get_working_list(working_list, max_constraint):
284287
autoscale = kwargs.get("autoscale", False)
285288
target_tpot = kwargs.get("target_tpot")
286289
max_concurrency = kwargs.get("max_concurrency")
290+
if max_concurrency is not None:
291+
logger.info("disagg_pareto: max_concurrency=%d is active; filtering compositions", max_concurrency)
287292

288293
summary = disagg_sess.find_best_disagg_result_under_constraints(
289294
model_path=model_path,

src/aiconfigurator/sdk/picking.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,7 @@ def pick_autoscale(
373373
target_ttft: float,
374374
target_tpot: float,
375375
top_n: int = 5,
376+
max_concurrency: int | None = None,
376377
) -> dict[str, Any]:
377378
"""Pick prefill and decode engines independently for autoscaling.
378379
@@ -462,6 +463,8 @@ def pick_autoscale(
462463
decode_summary_dict=d_row.to_dict(),
463464
decode_num_worker=1,
464465
)
466+
if max_concurrency is not None and combo["concurrency"] > max_concurrency:
467+
continue
465468
all_combos.append(combo)
466469

467470
if not all_combos:

src/aiconfigurator/sdk/task.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,9 @@ def __init__(
638638
effective_profiles = list(dict.fromkeys([*effective_profiles, *yaml_profiles]))
639639
yaml_patch = yaml_config.get("config", yaml_config)
640640

641+
if max_concurrency is not None and max_concurrency < 1:
642+
raise ValueError(f"max_concurrency must be >= 1, got {max_concurrency}")
643+
641644
ctx = TaskContext(
642645
serving_mode=serving_mode,
643646
model_path=model_path,
@@ -911,6 +914,8 @@ def _convert(obj: Any) -> Any:
911914
)
912915

913916
printable["enable_wideep"] = self.enable_wideep
917+
if self.max_concurrency is not None:
918+
printable["max_concurrency"] = self.max_concurrency
914919
printable["moe_backend"] = self.config.moe_backend
915920
printable["attention_backend"] = self.config.attention_backend
916921

tests/unit/sdk/task/test_task.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -832,3 +832,52 @@ def test_agg_max_concurrency_none_by_default(monkeypatch):
832832
TaskRunner().run(task)
833833

834834
assert captured.get("max_concurrency") is None
835+
836+
837+
def test_taskconfig_max_concurrency_zero_rejected():
838+
"""max_concurrency=0 should raise ValueError."""
839+
with pytest.raises(ValueError, match=r"max_concurrency must be >= 1"):
840+
TaskConfig(
841+
serving_mode="agg",
842+
model_path="Qwen/Qwen3-32B",
843+
system_name="h200_sxm",
844+
max_concurrency=0,
845+
)
846+
847+
848+
def test_taskconfig_max_concurrency_negative_rejected():
849+
"""Negative max_concurrency should raise ValueError."""
850+
with pytest.raises(ValueError, match=r"max_concurrency must be >= 1"):
851+
TaskConfig(
852+
serving_mode="agg",
853+
model_path="Qwen/Qwen3-32B",
854+
system_name="h200_sxm",
855+
max_concurrency=-5,
856+
)
857+
858+
859+
def test_taskconfig_to_yaml_includes_max_concurrency():
860+
"""to_yaml() must include max_concurrency when it is set."""
861+
task = TaskConfig(
862+
serving_mode="agg",
863+
model_path="Qwen/Qwen3-32B",
864+
system_name="h200_sxm",
865+
max_concurrency=256,
866+
)
867+
yaml_output = task.to_yaml()
868+
parsed = yaml.safe_load(yaml_output)
869+
task_name = task.task_name
870+
assert parsed[task_name]["max_concurrency"] == 256
871+
872+
873+
def test_taskconfig_to_yaml_omits_max_concurrency_when_none():
874+
"""to_yaml() must not include max_concurrency when it is None."""
875+
task = TaskConfig(
876+
serving_mode="agg",
877+
model_path="Qwen/Qwen3-32B",
878+
system_name="h200_sxm",
879+
)
880+
yaml_output = task.to_yaml()
881+
parsed = yaml.safe_load(yaml_output)
882+
task_name = task.task_name
883+
assert "max_concurrency" not in parsed[task_name]

0 commit comments

Comments
 (0)