Skip to content

Commit 7c518d5

Browse files
committed
Rename skip_chat_template to use_chat_template across codebase
Chat templates are now opt-in everywhere. Fixes inconsistency between the API (which defaulted to skip) and the legacy extraction parser (which defaulted to apply). Updates README accordingly.
1 parent fdbf6ec commit 7c518d5

File tree

5 files changed

+31
-26
lines changed

5 files changed

+31
-26
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ calc-dna \
9191

9292
- **Metadata auto-fetched**: Model metadata is automatically retrieved from HuggingFace Hub and cached.
9393
- **Auth token**: Pass via `token=...` or set `HF_TOKEN` environment variable.
94-
- **Chat templates**: Applied automatically when supported by the tokenizer.
94+
- **Chat templates**: Disabled by default. Enable with `--use-chat-template` (CLI) or `use_chat_template=True` (API).
9595

9696
## Tests
9797

src/llm_dna/api.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class DNAExtractionConfig:
4949
gpu_id: Optional[int] = None
5050
log_level: str = "INFO"
5151
random_seed: int = 42
52-
skip_chat_template: bool = True
52+
use_chat_template: bool = False
5353

5454

5555
@dataclass(slots=True)
@@ -413,7 +413,7 @@ def _save_response_incrementally(idx: int, prompt: str, response: str) -> None:
413413
temperature=0.0,
414414
do_sample=False,
415415
top_p=1.0,
416-
skip_chat_template=config.skip_chat_template,
416+
use_chat_template=config.use_chat_template,
417417
on_response_callback=_save_response_incrementally if incremental_save_path else None,
418418
)
419419
else:
@@ -425,7 +425,7 @@ def _save_response_incrementally(idx: int, prompt: str, response: str) -> None:
425425
temperature=0.0,
426426
do_sample=False,
427427
top_p=1.0,
428-
skip_chat_template=config.skip_chat_template,
428+
use_chat_template=config.use_chat_template,
429429
)
430430
responses.append(response)
431431
if incremental_save_path:
@@ -628,7 +628,7 @@ def calc_dna(config: DNAExtractionConfig) -> DNAExtractionResult:
628628
device=resolved_device,
629629
log_level=config.log_level,
630630
random_seed=config.random_seed,
631-
skip_chat_template=config.skip_chat_template,
631+
use_chat_template=config.use_chat_template,
632632
)
633633

634634
signature = core.extract_dna_signature(

src/llm_dna/cli.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,12 @@ def parse_arguments(argv: Optional[Iterable[str]] = None) -> argparse.Namespace:
169169
default="INFO",
170170
)
171171
parser.add_argument("--random-seed", type=int, default=42)
172-
parser.add_argument("--skip-chat-template", action="store_true")
172+
parser.add_argument(
173+
"--use-chat-template",
174+
action="store_true",
175+
default=False,
176+
help="Apply chat template for HuggingFace models (default: disabled).",
177+
)
173178

174179
return parser.parse_args(list(argv) if argv is not None else None)
175180

@@ -221,7 +226,7 @@ def main(argv: Optional[Iterable[str]] = None) -> int:
221226
gpu_id=None,
222227
log_level=args.log_level,
223228
random_seed=args.random_seed,
224-
skip_chat_template=args.skip_chat_template,
229+
use_chat_template=args.use_chat_template,
225230
)
226231
try:
227232
results = calc_dna_parallel(
@@ -276,7 +281,7 @@ def main(argv: Optional[Iterable[str]] = None) -> int:
276281
gpu_id=gpu_id,
277282
log_level=args.log_level,
278283
random_seed=args.random_seed,
279-
skip_chat_template=args.skip_chat_template,
284+
use_chat_template=args.use_chat_template,
280285
)
281286

282287
result = calc_dna(config)

src/llm_dna/core/extraction.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,10 @@ def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
206206
help="Random seed for reproducibility"
207207
)
208208
parser.add_argument(
209-
"--skip-chat-template",
209+
"--use-chat-template",
210210
action="store_true",
211211
default=False,
212-
help="Skip applying chat templates for chat models (treat them as completion models). By default, chat templates are applied for embedding extractor."
212+
help="Apply chat templates for chat models. By default, chat templates are not applied."
213213
)
214214

215215
return parser.parse_args(argv)
@@ -384,7 +384,7 @@ def extract_dna_signature(
384384
"""Extract DNA signature from model."""
385385

386386
# Apply chat template only when explicitly requested (use_chat_template / --use-chat-template).
387-
if extractor_type == "embedding" and not args.skip_chat_template:
387+
if extractor_type == "embedding" and args.use_chat_template:
388388
is_chat_model = model_metadata.get("chat_model", {}).get("is_chat_model", False)
389389
should_try_template = is_chat_model or "chat_model" not in model_metadata
390390
try:

src/llm_dna/models/ModelWrapper.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,7 @@ def generate(
632632
temperature: float = 0.7,
633633
do_sample: bool = True,
634634
top_p: float = 0.9,
635-
skip_chat_template: bool = False,
635+
use_chat_template: bool = False,
636636
**kwargs
637637
) -> str:
638638
"""Generate text from input, respecting the model's context length."""
@@ -648,10 +648,10 @@ def generate(
648648
safe_input_length, max_new_tokens = self._get_safe_generation_params(max_length)
649649

650650
# Chat-template tokenization ensures special tokens are handled correctly for chat models.
651-
# Skip chat template if skip_chat_template=True (treat chat models as completion models)
651+
# Apply chat template only when use_chat_template=True
652652
inputs = None
653653
prefers_chat_template = False
654-
if not skip_chat_template:
654+
if use_chat_template:
655655
if self.is_chat_model is True:
656656
prefers_chat_template = True
657657
# Heuristic fallback if metadata wasn't provided
@@ -757,7 +757,7 @@ def generate(
757757
# Decode only the newly generated tokens
758758
new_tokens = outputs[0][input_length:]
759759
# When the chat template is not applied, preserve special tokens in the decoded output (match try_chat_model_without_template.py behavior)
760-
skip_special_tokens = not skip_chat_template
760+
skip_special_tokens = use_chat_template
761761
generated_text = self.tokenizer.decode(new_tokens, skip_special_tokens=skip_special_tokens)
762762

763763
return generated_text.strip()
@@ -1035,9 +1035,9 @@ def __init__(
10351035
except Exception as e:
10361036
raise RuntimeError(f"Failed to initialize vLLM engine: {e}")
10371037

1038-
def _format_prompt(self, user_text: str, skip_chat_template: bool = False) -> str:
1039-
# Prefer chat template for chat models, unless skip_chat_template=True
1040-
if skip_chat_template:
1038+
def _format_prompt(self, user_text: str, use_chat_template: bool = False) -> str:
1039+
# Apply chat template for chat models only when use_chat_template=True
1040+
if not use_chat_template:
10411041
return user_text
10421042
try:
10431043
prefers_chat = False
@@ -1062,12 +1062,12 @@ def generate(
10621062
temperature: float = 0.7,
10631063
do_sample: bool = True,
10641064
top_p: float = 0.9,
1065-
skip_chat_template: bool = False,
1065+
use_chat_template: bool = False,
10661066
**kwargs
10671067
) -> str:
10681068
try:
10691069
from vllm import SamplingParams
1070-
prompt = self._format_prompt(input_text, skip_chat_template=skip_chat_template)
1070+
prompt = self._format_prompt(input_text, use_chat_template=use_chat_template)
10711071
# Map our "max_length" contract to vLLM's max_tokens for new tokens
10721072
# Our safe length logic is in HF wrapper; here we approximate with max_tokens
10731073
params = SamplingParams(
@@ -1091,7 +1091,7 @@ def generate_batch(
10911091
temperature: float = 0.7,
10921092
do_sample: bool = True,
10931093
top_p: float = 0.9,
1094-
skip_chat_template: bool = False,
1094+
use_chat_template: bool = False,
10951095
**kwargs
10961096
) -> List[str]:
10971097
"""Generate for a list of prompts in one vLLM call.
@@ -1103,7 +1103,7 @@ def generate_batch(
11031103
return []
11041104
try:
11051105
from vllm import SamplingParams
1106-
formatted = [self._format_prompt(p, skip_chat_template=skip_chat_template) for p in prompts]
1106+
formatted = [self._format_prompt(p, use_chat_template=use_chat_template) for p in prompts]
11071107
params = SamplingParams(
11081108
max_tokens=max_length,
11091109
temperature=temperature,
@@ -1122,7 +1122,7 @@ def generate_batch(
11221122
self.logger.error(f"vLLM batch generation failed: {e}")
11231123
# Fall back to sequential to salvage outputs
11241124
return [
1125-
self.generate(p, max_length=max_length, temperature=temperature, do_sample=do_sample, top_p=top_p, skip_chat_template=skip_chat_template, **kwargs)
1125+
self.generate(p, max_length=max_length, temperature=temperature, do_sample=do_sample, top_p=top_p, use_chat_template=use_chat_template, **kwargs)
11261126
for p in prompts
11271127
]
11281128

@@ -1383,7 +1383,7 @@ def generate_batch(
13831383
completion_window = str(kwargs.pop("batch_completion_window", "24h"))
13841384

13851385
# Wrapper-only kwargs should not be forwarded to the provider payload.
1386-
kwargs.pop("skip_chat_template", None)
1386+
kwargs.pop("use_chat_template", None)
13871387

13881388
if not prefer_batch_api:
13891389
return super().generate_batch(
@@ -1979,7 +1979,7 @@ def generate(
19791979
top_p: float = 0.9,
19801980
**kwargs
19811981
) -> str:
1982-
kwargs.pop("skip_chat_template", None)
1982+
kwargs.pop("use_chat_template", None)
19831983
payload = {
19841984
"contents": [{"role": "user", "parts": [{"text": input_text}]}],
19851985
"generationConfig": self._build_gemini_generation_config(
@@ -2022,7 +2022,7 @@ def generate_batch(
20222022
timeout = kwargs.pop("batch_timeout_seconds", self.batch_timeout_seconds)
20232023

20242024
# Wrapper-only kwarg
2025-
kwargs.pop("skip_chat_template", None)
2025+
kwargs.pop("use_chat_template", None)
20262026

20272027
if not prefer_batch_api:
20282028
return super().generate_batch(

0 commit comments

Comments
 (0)