@@ -22,12 +22,14 @@ import * as path from "node:path/posix";
 import { snippets } from "@huggingface/inference";
 import type { SnippetInferenceProvider, InferenceSnippet, ModelDataMinimal } from "@huggingface/tasks";
 
-type LANGUAGE = "sh" | "js" | "py";
+const LANGUAGES = ["sh", "js", "python"] as const;
+type Language = (typeof LANGUAGES)[number];
+const EXTENSIONS: Record<Language, string> = { sh: "sh", js: "js", python: "py" };
 
 const TEST_CASES: {
 	testName: string;
 	model: ModelDataMinimal;
-	languages: LANGUAGE[];
+	languages: Language[];
 	providers: SnippetInferenceProvider[];
 	opts?: Record<string, unknown>;
 }[] = [
@@ -39,7 +41,7 @@ const TEST_CASES: {
 			tags: [],
 			inference: "",
 		},
-		languages: ["py"],
+		languages: ["python"],
 		providers: ["hf-inference"],
 	},
 	{
@@ -50,7 +52,7 @@ const TEST_CASES: {
 			tags: ["conversational"],
 			inference: "",
 		},
-		languages: ["sh", "js", "py"],
+		languages: ["sh", "js", "python"],
 		providers: ["hf-inference", "together"],
 		opts: { streaming: false },
 	},
@@ -62,7 +64,7 @@ const TEST_CASES: {
 			tags: ["conversational"],
 			inference: "",
 		},
-		languages: ["sh", "js", "py"],
+		languages: ["sh", "js", "python"],
 		providers: ["hf-inference", "together"],
 		opts: { streaming: true },
 	},
@@ -74,7 +76,7 @@ const TEST_CASES: {
 			tags: ["conversational"],
 			inference: "",
 		},
-		languages: ["sh", "js", "py"],
+		languages: ["sh", "js", "python"],
 		providers: ["hf-inference", "fireworks-ai"],
 		opts: { streaming: false },
 	},
@@ -86,7 +88,7 @@ const TEST_CASES: {
 			tags: ["conversational"],
 			inference: "",
 		},
-		languages: ["sh", "js", "py"],
+		languages: ["sh", "js", "python"],
 		providers: ["hf-inference", "fireworks-ai"],
 		opts: { streaming: true },
 	},
@@ -98,7 +100,7 @@ const TEST_CASES: {
 			tags: [],
 			inference: "",
 		},
-		languages: ["py"],
+		languages: ["python"],
 		providers: ["hf-inference"],
 	},
 	{
@@ -109,7 +111,7 @@ const TEST_CASES: {
 			tags: [],
 			inference: "",
 		},
-		languages: ["py"],
+		languages: ["python"],
 		providers: ["hf-inference"],
 	},
 	{
@@ -121,7 +123,7 @@ const TEST_CASES: {
 			inference: "",
 		},
 		providers: ["hf-inference"],
-		languages: ["py"],
+		languages: ["python"],
 	},
 	{
 		testName: "text-to-audio-transformers",
@@ -132,7 +134,7 @@ const TEST_CASES: {
 			inference: "",
 		},
 		providers: ["hf-inference"],
-		languages: ["py"],
+		languages: ["python"],
 	},
 	{
 		testName: "text-to-image",
@@ -143,7 +145,7 @@ const TEST_CASES: {
 			inference: "",
 		},
 		providers: ["hf-inference", "fal-ai"],
-		languages: ["sh", "js", "py"],
+		languages: ["sh", "js", "python"],
 	},
 	{
 		testName: "text-to-video",
@@ -154,7 +156,7 @@ const TEST_CASES: {
 			inference: "",
 		},
 		providers: ["replicate", "fal-ai"],
-		languages: ["js", "py"],
+		languages: ["js", "python"],
 	},
 	{
 		testName: "text-classification",
@@ -165,7 +167,7 @@ const TEST_CASES: {
 			inference: "",
 		},
 		providers: ["hf-inference"],
-		languages: ["sh", "js", "py"],
+		languages: ["sh", "js", "python"],
 	},
 	{
 		testName: "basic-snippet--token-classification",
@@ -176,7 +178,7 @@ const TEST_CASES: {
 			inference: "",
 		},
 		providers: ["hf-inference"],
-		languages: ["py"],
+		languages: ["python"],
 	},
 	{
 		testName: "zero-shot-classification",
@@ -187,7 +189,7 @@ const TEST_CASES: {
 			inference: "",
 		},
 		providers: ["hf-inference"],
-		languages: ["py"],
+		languages: ["python"],
 	},
 	{
 		testName: "zero-shot-image-classification",
@@ -198,14 +200,14 @@ const TEST_CASES: {
 			inference: "",
 		},
 		providers: ["hf-inference"],
-		languages: ["py"],
+		languages: ["python"],
 	},
 ] as const;
 
 const GET_SNIPPET_FN = {
 	sh: snippets.curl.getCurlInferenceSnippet,
 	js: snippets.js.getJsInferenceSnippet,
-	py: snippets.python.getPythonInferenceSnippet,
+	python: snippets.python.getPythonInferenceSnippet,
 } as const;
 
 const rootDirFinder = (): string => {
@@ -228,42 +230,51 @@ function getFixtureFolder(testName: string): string {
 
 function generateInferenceSnippet(
 	model: ModelDataMinimal,
-	language: LANGUAGE,
+	language: Language,
 	provider: SnippetInferenceProvider,
 	opts?: Record<string, unknown>
 ): InferenceSnippet[] {
 	const providerModelId = provider === "hf-inference" ? model.id : `<${provider} alias for ${model.id}>`;
-	return GET_SNIPPET_FN[language](model, "api_token", provider, providerModelId, opts);
+	const snippets = GET_SNIPPET_FN[language](model, "api_token", provider, providerModelId, opts) as InferenceSnippet[];
+	return snippets.sort((snippetA, snippetB) => snippetA.client.localeCompare(snippetB.client));
 }
 
 async function getExpectedInferenceSnippet(
 	testName: string,
-	language: LANGUAGE,
+	language: Language,
 	provider: SnippetInferenceProvider
 ): Promise<InferenceSnippet[]> {
 	const fixtureFolder = getFixtureFolder(testName);
-	const files = await fs.readdir(fixtureFolder);
+	const languageFolder = path.join(fixtureFolder, language);
+	const files = await fs.readdir(languageFolder, { recursive: true });
 
 	const expectedSnippets: InferenceSnippet[] = [];
-	for (const file of files.filter((file) => file.endsWith("." + language) && file.includes(`.${provider}.`)).sort()) {
-		const client = path.basename(file).split(".").slice(1, -2).join("."); // e.g. '0.huggingface.js.replicate.js' => "huggingface.js"
-		const content = await fs.readFile(path.join(fixtureFolder, file), { encoding: "utf-8" });
+	for (const file of files.filter((file) => file.includes(`.${provider}.`)).sort()) {
+		const client = file.split("/")[0]; // e.g. fal_client/1.fal-ai.python => fal_client
+		const content = await fs.readFile(path.join(languageFolder, file), { encoding: "utf-8" });
 		expectedSnippets.push({ client, content });
 	}
 	return expectedSnippets;
 }
 
 async function saveExpectedInferenceSnippet(
 	testName: string,
-	language: LANGUAGE,
+	language: Language,
 	provider: SnippetInferenceProvider,
 	snippets: InferenceSnippet[]
 ) {
 	const fixtureFolder = getFixtureFolder(testName);
 	await fs.mkdir(fixtureFolder, { recursive: true });
 
-	for (const [index, snippet] of snippets.entries()) {
-		const file = path.join(fixtureFolder, `${index}.${snippet.client ?? "default"}.${provider}.${language}`);
+	const indexPerClient = new Map<string, number>();
+	for (const snippet of snippets) {
+		const extension = EXTENSIONS[language];
+		const client = snippet.client;
+		const index = indexPerClient.get(client) ?? 0;
+		indexPerClient.set(client, index + 1);
+
+		const file = path.join(fixtureFolder, language, snippet.client, `${index}.${provider}.${extension}`);
+		await fs.mkdir(path.dirname(file), { recursive: true });
 		await fs.writeFile(file, snippet.content);
 	}
 }
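
For context, the helpers changed above are presumably consumed by a test loop elsewhere in this file (not shown in this excerpt). The following is only a rough sketch of what such a driver could look like, assuming a vitest-style runner and the TEST_CASES, generateInferenceSnippet, and getExpectedInferenceSnippet names from the diff; it is not the repository's actual test body.

// Hypothetical driver (assumption, not part of the commit): one test per
// test case / language / provider combination, comparing freshly generated
// snippets against the fixtures stored on disk.
import { describe, expect, it } from "vitest";

describe("inference snippet fixtures", () => {
	for (const { testName, model, languages, providers, opts } of TEST_CASES) {
		for (const language of languages) {
			for (const provider of providers) {
				it(`${testName} (${language}, ${provider})`, async () => {
					const generated = generateInferenceSnippet(model, language, provider, opts);
					const expected = await getExpectedInferenceSnippet(testName, language, provider);
					expect(generated).toEqual(expected);
				});
			}
		}
	}
});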