Skip to content

Commit 389910f

Browse files
committed
llmapi: reduce fallback repetition and harden support-matrix sync
Centralize model feature fallback disabling and standardize the warning format. Mark supported-models.md as generated and add basic generator/data invariants to prevent silent drift.

Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com>
1 parent e4f0a5a commit 389910f

File tree

4 files changed

+107
-36
lines changed

4 files changed

+107
-36
lines changed

docs/source/models/supported-models.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
 (support-matrix)=
+<!-- Generated from tensorrt_llm/llmapi/model_support_matrix.py; do not edit. -->
 # Supported Models
 
 The following is a table of supported models for the PyTorch backend:

tensorrt_llm/llmapi/llm.py

Lines changed: 52 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -261,44 +261,61 @@ def _apply_model_feature_fallbacks(self) -> None:
             return
         arch = archs[0]
 
+        def _disable_if_unsupported(
+            feature: SupportFeature,
+            *,
+            enabled: bool,
+            arg_path: str,
+            disable,
+        ) -> None:
+            # Preserve behavior: only override when user explicitly enabled it.
+            if not enabled:
+                return
+            status = get_support_status(arch, feature)
+            # Preserve behavior: unknown/untested/missing status must not disable anything.
+            if status not in (SupportStatus.NO, SupportStatus.NA):
+                return
+            logger.warning(
+                f"{arch}: {feature.value} unsupported; disabling {arg_path}")
+            disable()
+
         kv_cfg = getattr(self.args, "kv_cache_config", None)
-        if kv_cfg is not None and getattr(kv_cfg, "enable_block_reuse", False):
-            if get_support_status(
-                    arch, SupportFeature.KV_CACHE_REUSE) in (SupportStatus.NO,
-                                                             SupportStatus.NA):
-                logger.warning(
-                    f"{arch}: KV cache reuse unsupported; setting kv_cache_config.enable_block_reuse=False"
-                )
+
+        def _disable_kv_cache_reuse() -> None:
+            if kv_cfg is not None:
                 kv_cfg.enable_block_reuse = False
 
-        if getattr(self.args, "enable_chunked_prefill", False):
-            if get_support_status(
-                    arch, SupportFeature.CHUNKED_PREFILL) in (SupportStatus.NO,
-                                                              SupportStatus.NA):
-                logger.warning(
-                    f"{arch}: Chunked prefill unsupported; setting enable_chunked_prefill=False"
-                )
-                self.args.enable_chunked_prefill = False
-
-        if getattr(self.args, "enable_attention_dp", False):
-            if get_support_status(
-                    arch, SupportFeature.ATTENTION_DP) in (SupportStatus.NO,
-                                                           SupportStatus.NA):
-                logger.warning(
-                    f"{arch}: Attention DP unsupported; setting enable_attention_dp=False"
-                )
-                self.args.enable_attention_dp = False
-
-        if hasattr(self.args, "disable_overlap_scheduler") and getattr(
-                self.args, "disable_overlap_scheduler") is False:
-            if get_support_status(
-                    arch,
-                    SupportFeature.OVERLAP_SCHEDULER) in (SupportStatus.NO,
-                                                          SupportStatus.NA):
-                logger.warning(
-                    f"{arch}: Overlap scheduler unsupported; setting disable_overlap_scheduler=True"
-                )
-                self.args.disable_overlap_scheduler = True
+        _disable_if_unsupported(
+            SupportFeature.KV_CACHE_REUSE,
+            enabled=kv_cfg is not None
+            and getattr(kv_cfg, "enable_block_reuse", False),
+            arg_path="kv_cache_config.enable_block_reuse",
+            disable=_disable_kv_cache_reuse,
+        )
+
+        _disable_if_unsupported(
+            SupportFeature.CHUNKED_PREFILL,
+            enabled=getattr(self.args, "enable_chunked_prefill", False),
+            arg_path="enable_chunked_prefill",
+            disable=lambda: setattr(self.args, "enable_chunked_prefill", False),
+        )
+
+        _disable_if_unsupported(
+            SupportFeature.ATTENTION_DP,
+            enabled=getattr(self.args, "enable_attention_dp", False),
+            arg_path="enable_attention_dp",
+            disable=lambda: setattr(self.args, "enable_attention_dp", False),
+        )
+
+        # disable_overlap_scheduler is inverted: we only flip it when currently False.
+        _disable_if_unsupported(
+            SupportFeature.OVERLAP_SCHEDULER,
+            enabled=hasattr(self.args, "disable_overlap_scheduler")
+            and getattr(self.args, "disable_overlap_scheduler") is False,
+            arg_path="disable_overlap_scheduler",
+            disable=lambda: setattr(self.args, "disable_overlap_scheduler", True
+                                    ),
+        )
 
     @property
     @set_api_status("beta")

tensorrt_llm/llmapi/model_support_matrix.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,7 @@ def render_supported_models_markdown() -> str:
     """Render the full `docs/source/models/supported-models.md` content."""
     out: List[str] = []
     out.append("(support-matrix)=")
+    out.append("<!-- Generated from tensorrt_llm/llmapi/model_support_matrix.py; do not edit. -->")
     out.append("# Supported Models")
     out.append("")
     out.append("The following is a table of supported models for the PyTorch backend:")

tests/unittest/tools/test_supported_models_sync.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@
2121

2222

2323
def _render_supported_models_markdown(repo_root: Path) -> str:
24+
module = _load_model_support_matrix_module(repo_root)
25+
return module.render_supported_models_markdown()
26+
27+
28+
def _load_model_support_matrix_module(repo_root: Path):
2429
module_path = repo_root / "tensorrt_llm/llmapi/model_support_matrix.py"
2530
spec = importlib.util.spec_from_file_location("tllm_model_support_matrix", module_path)
2631
if spec is None or spec.loader is None:
@@ -30,7 +35,7 @@ def _render_supported_models_markdown(repo_root: Path) -> str:
     # Needed for dataclasses/type evaluation during module exec.
     sys.modules[spec.name] = module
     spec.loader.exec_module(module)
-    return module.render_supported_models_markdown()
+    return module
 
 
 class TestSupportedModelsSync(unittest.TestCase):
@@ -50,6 +55,53 @@ def test_supported_models_md_sync(self):
             "Please regenerate it (e.g. build docs, or run the generator entrypoint in docs/source/helper.py).",
         )
 
+    def test_supported_models_matrix_invariants(self):
+        """Catch support-matrix drift early (ordering, duplication, footnotes)."""
+        repo_root = Path(__file__).resolve().parents[3]
+        module = _load_model_support_matrix_module(repo_root)
+
+        self.assertEqual(
+            set(module.KEY_MODEL_ARCH_ORDER),
+            set(module.KEY_MODEL_MATRIX.keys()),
+            "KEY_MODEL_ARCH_ORDER must match KEY_MODEL_MATRIX keys (no missing/extra rows).",
+        )
+        self.assertEqual(
+            set(module.MULTIMODAL_ARCH_ORDER),
+            set(module.MULTIMODAL_MATRIX.keys()),
+            "MULTIMODAL_ARCH_ORDER must match MULTIMODAL_MATRIX keys (no missing/extra rows).",
+        )
+
+        self.assertEqual(
+            len(module.KEY_MODEL_ARCH_ORDER),
+            len(set(module.KEY_MODEL_ARCH_ORDER)),
+            "KEY_MODEL_ARCH_ORDER contains duplicate architectures.",
+        )
+        self.assertEqual(
+            len(module.MULTIMODAL_ARCH_ORDER),
+            len(set(module.MULTIMODAL_ARCH_ORDER)),
+            "MULTIMODAL_ARCH_ORDER contains duplicate architectures.",
+        )
+
+        architectures = [m.architecture for m in module.SUPPORTED_MODELS_PYTORCH]
+        self.assertEqual(
+            len(architectures),
+            len(set(architectures)),
+            "SUPPORTED_MODELS_PYTORCH contains duplicate architectures.",
+        )
+
+        used_footnotes = set()
+        for row in module.KEY_MODEL_MATRIX.values():
+            for cell in row.values():
+                footnote = getattr(cell, "footnote", None)
+                if footnote:
+                    used_footnotes.add(footnote)
+
+        for fn in used_footnotes:
+            self.assertTrue(
+                any(note.startswith(f"{fn}:") for note in module.KEY_MODEL_FOOTNOTES),
+                f"Missing footnote definition for {fn} in KEY_MODEL_FOOTNOTES.",
+            )
+
 
 if __name__ == "__main__":
     unittest.main()

0 commit comments

Comments (0)