From 20645c60af5eb0f7c11ec070c1e775efa9eea1da Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 11 Jul 2025 13:57:32 +0200 Subject: [PATCH 1/3] Richer conversational snippet for AutoModel --- .../tasks/src/model-libraries-snippets.ts | 51 +++++++++++++++---- 1 file changed, 41 insertions(+), 10 deletions(-) diff --git a/packages/tasks/src/model-libraries-snippets.ts b/packages/tasks/src/model-libraries-snippets.ts index 9782c3bfea..5f4e0546aa 100644 --- a/packages/tasks/src/model-libraries-snippets.ts +++ b/packages/tasks/src/model-libraries-snippets.ts @@ -1373,27 +1373,58 @@ export const transformers = (model: ModelData): string[] => { } const remote_code_snippet = model.tags.includes(TAG_CUSTOM_CODE) ? ", trust_remote_code=True" : ""; - let autoSnippet: string; + const autoSnippet = []; if (info.processor) { - const varName = + const processorVarName = info.processor === "AutoTokenizer" ? "tokenizer" : info.processor === "AutoFeatureExtractor" ? "extractor" : "processor"; - autoSnippet = [ + autoSnippet.push([ "# Load model directly", `from transformers import ${info.processor}, ${info.auto_model}`, "", - `${varName} = ${info.processor}.from_pretrained("${model.id}"` + remote_code_snippet + ")", + `${processorVarName} = ${info.processor}.from_pretrained("${model.id}"` + remote_code_snippet + ")", `model = ${info.auto_model}.from_pretrained("${model.id}"` + remote_code_snippet + ")", - ].join("\n"); + ]); + if (model.tags.includes("conversational")) { + if (model.tags.includes("image-text-to-text")) { + autoSnippet.push( + "messages = [", + [ + " {", + ' "role": "user",', + ' "content": [', + ' {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG"},', + ' {"type": "text", "text": "What animal is on the candy?"}', + " ]", + " },", + ].join("\n"), + "]" + ); + } else { + autoSnippet.push("messages = [", ' {"role": "user", "content": "Who are you?"},', "]"); + } + autoSnippet.push( + "inputs = ${processorVarName}.apply_chat_template(", + " messages,", + " add_generation_prompt=True,", + " tokenize=True,", + " return_dict=True,", + ' return_tensors="pt",', + ").to(model.device)", + "", + "outputs = model.generate(**inputs, max_new_tokens=40)", + 'print(${processorVarName}.decode(outputs[0][inputs["input_ids"].shape[-1]:]))', + ); + } } else { - autoSnippet = [ + autoSnippet.push([ "# Load model directly", `from transformers import ${info.auto_model}`, - `model = ${info.auto_model}.from_pretrained("${model.id}"` + remote_code_snippet + ")", - ].join("\n"); + `model = ${info.auto_model}.from_pretrained("${model.id}"` + remote_code_snippet + ', torch_dtype="auto"),', + ]); } if (model.pipeline_tag && LIBRARY_TASK_MAPPING.transformers?.includes(model.pipeline_tag)) { @@ -1437,9 +1468,9 @@ export const transformers = (model: ModelData): string[] => { ); } - return [pipelineSnippet.join("\n"), autoSnippet]; + return [pipelineSnippet.join("\n"), autoSnippet.join("\n")]; } - return [autoSnippet]; + return [autoSnippet.join("\n")]; }; export const transformersJS = (model: ModelData): string[] => { From bd84b823c3c60c2decb107177455929dd8de70a2 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 11 Jul 2025 14:18:30 +0200 Subject: [PATCH 2/3] Lint, fixes --- packages/tasks/src/model-libraries-snippets.ts | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/packages/tasks/src/model-libraries-snippets.ts b/packages/tasks/src/model-libraries-snippets.ts index 5f4e0546aa..28d855deca 100644 --- a/packages/tasks/src/model-libraries-snippets.ts +++ b/packages/tasks/src/model-libraries-snippets.ts @@ -1381,13 +1381,13 @@ export const transformers = (model: ModelData): string[] => { : info.processor === "AutoFeatureExtractor" ? "extractor" : "processor"; - autoSnippet.push([ + autoSnippet.push( "# Load model directly", `from transformers import ${info.processor}, ${info.auto_model}`, "", `${processorVarName} = ${info.processor}.from_pretrained("${model.id}"` + remote_code_snippet + ")", `model = ${info.auto_model}.from_pretrained("${model.id}"` + remote_code_snippet + ")", - ]); + ); if (model.tags.includes("conversational")) { if (model.tags.includes("image-text-to-text")) { autoSnippet.push( @@ -1416,15 +1416,15 @@ export const transformers = (model: ModelData): string[] => { ").to(model.device)", "", "outputs = model.generate(**inputs, max_new_tokens=40)", - 'print(${processorVarName}.decode(outputs[0][inputs["input_ids"].shape[-1]:]))', + 'print(${processorVarName}.decode(outputs[0][inputs["input_ids"].shape[-1]:]))' ); } } else { - autoSnippet.push([ + autoSnippet.push( "# Load model directly", `from transformers import ${info.auto_model}`, - `model = ${info.auto_model}.from_pretrained("${model.id}"` + remote_code_snippet + ', torch_dtype="auto"),', - ]); + `model = ${info.auto_model}.from_pretrained("${model.id}"` + remote_code_snippet + ', torch_dtype="auto"),' + ); } if (model.pipeline_tag && LIBRARY_TASK_MAPPING.transformers?.includes(model.pipeline_tag)) { From 80e60375c606cf51f7380cbf6517f1c287810638 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 11 Jul 2025 14:20:17 +0200 Subject: [PATCH 3/3] lint trailing comma --- packages/tasks/src/model-libraries-snippets.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/tasks/src/model-libraries-snippets.ts b/packages/tasks/src/model-libraries-snippets.ts index 28d855deca..39e973b511 100644 --- a/packages/tasks/src/model-libraries-snippets.ts +++ b/packages/tasks/src/model-libraries-snippets.ts @@ -1386,7 +1386,7 @@ export const transformers = (model: ModelData): string[] => { `from transformers import ${info.processor}, ${info.auto_model}`, "", `${processorVarName} = ${info.processor}.from_pretrained("${model.id}"` + remote_code_snippet + ")", - `model = ${info.auto_model}.from_pretrained("${model.id}"` + remote_code_snippet + ")", + `model = ${info.auto_model}.from_pretrained("${model.id}"` + remote_code_snippet + ")" ); if (model.tags.includes("conversational")) { if (model.tags.includes("image-text-to-text")) {