@@ -4,6 +4,11 @@ import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
44import { getModelInputSnippet } from "./inputs.js" ;
55import type { InferenceSnippet , ModelDataMinimal } from "./types.js" ;
66
7+ const snippetImportInferenceClient = ( model : ModelDataMinimal , accessToken : string ) : string =>
8+ `from huggingface_hub import InferenceClient
9+
10+ client = InferenceClient(${ model . id } , token="${ accessToken || "{API_TOKEN}" } ")` ;
11+
712export const snippetConversational = (
813 model : ModelDataMinimal ,
914 accessToken : string ,
@@ -184,18 +189,31 @@ export const snippetFile = (model: ModelDataMinimal): InferenceSnippet => ({
184189output = query(${ getModelInputSnippet ( model ) } )` ,
185190} ) ;
186191
187- export const snippetTextToImage = ( model : ModelDataMinimal ) : InferenceSnippet => ( {
188- content : `def query(payload):
192+ export const snippetTextToImage = ( model : ModelDataMinimal , accessToken : string ) : InferenceSnippet [ ] => {
193+ return [
194+ {
195+ client : "requests" ,
196+ content : `def query(payload):
189197 response = requests.post(API_URL, headers=headers, json=payload)
190198 return response.content
199+
191200image_bytes = query({
192201 "inputs": ${ getModelInputSnippet ( model ) } ,
193202})
194203# You can access the image with PIL.Image for example
195204import io
196205from PIL import Image
197206image = Image.open(io.BytesIO(image_bytes))` ,
198- } ) ;
207+ } ,
208+ {
209+ client : "huggingface_hub" ,
210+ content : `${ snippetImportInferenceClient ( model , accessToken ) }
211+
212+ # output is a PIL.Image object
213+ image = client.text_to_image(${ getModelInputSnippet ( model ) } )` ,
214+ } ,
215+ ] ;
216+ } ;
199217
200218export const snippetTabular = ( model : ModelDataMinimal ) : InferenceSnippet => ( {
201219 content : `def query(payload):
@@ -300,6 +318,9 @@ export function getPythonInferenceSnippet(
300318 if ( model . tags . includes ( "conversational" ) ) {
301319 // Conversational model detected, so we display a code snippet that features the Messages API
302320 return snippetConversational ( model , accessToken , opts ) ;
321+ } else if ( model . pipeline_tag == "text-to-image" ) {
322+ // TODO: factorize this logic
323+ return snippetTextToImage ( model , accessToken ) ;
303324 } else {
304325 let snippets =
305326 model . pipeline_tag && model . pipeline_tag in pythonSnippets
0 commit comments