@@ -4,6 +4,10 @@ import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
44import { getModelInputSnippet } from "./inputs.js" ;
55import type { InferenceSnippet , ModelDataMinimal } from "./types.js" ;
66
7+ const snippetImportInferenceClient = ( model : ModelDataMinimal , accessToken : string ) : string =>
8+ `from huggingface_hub import InferenceClient
9+ client = InferenceClient("${ model . id } ", token="${ accessToken || "{API_TOKEN}" } ")` ;
10+
711export const snippetConversational = (
812 model : ModelDataMinimal ,
913 accessToken : string ,
@@ -161,8 +165,16 @@ export const snippetFile = (model: ModelDataMinimal): InferenceSnippet => ({
161165output = query(${ getModelInputSnippet ( model ) } )` ,
162166} ) ;
163167
164- export const snippetTextToImage = ( model : ModelDataMinimal ) : InferenceSnippet => ( {
165- content : `def query(payload):
168+ export const snippetTextToImage = ( model : ModelDataMinimal , accessToken : string ) : InferenceSnippet [ ] => [
169+ {
170+ client : "huggingface_hub" ,
171+ content : `${ snippetImportInferenceClient ( model , accessToken ) }
172+ # output is a PIL.Image object
173+ image = client.text_to_image(${ getModelInputSnippet ( model ) } )` ,
174+ } ,
175+ {
176+ client : "requests" ,
177+ content : `def query(payload):
166178 response = requests.post(API_URL, headers=headers, json=payload)
167179 return response.content
168180image_bytes = query({
@@ -172,7 +184,8 @@ image_bytes = query({
172184import io
173185from PIL import Image
174186image = Image.open(io.BytesIO(image_bytes))` ,
175- } ) ;
187+ } ,
188+ ] ;
176189
177190export const snippetTabular = ( model : ModelDataMinimal ) : InferenceSnippet => ( {
178191 content : `def query(payload):
@@ -288,12 +301,14 @@ export function getPythonInferenceSnippet(
288301 return snippets . map ( ( snippet ) => {
289302 return {
290303 ...snippet ,
291- content : `import requests
304+ content : snippet . content . includes ( "requests" )
305+ ? `import requests
292306
293307API_URL = "https://api-inference.huggingface.co/models/${ model . id } "
294308headers = {"Authorization": ${ accessToken ? `"Bearer ${ accessToken } "` : `f"Bearer {API_TOKEN}"` } }
295309
296- ${ snippet . content } `,
310+ ${ snippet . content } `
311+ : snippet . content ,
297312 } ;
298313 } ) ;
299314 }
0 commit comments