@@ -19,7 +19,7 @@ import { existsSync as pathExists } from "node:fs";
1919import * as fs from "node:fs/promises" ;
2020import * as path from "node:path/posix" ;
2121
22- import type { InferenceSnippet } from "@huggingface/tasks" ;
22+ import type { InferenceProvider , InferenceSnippet } from "@huggingface/tasks" ;
2323import { snippets } from "@huggingface/tasks" ;
2424
2525type LANGUAGE = "sh" | "js" | "py" ;
@@ -28,6 +28,7 @@ const TEST_CASES: {
2828 testName : string ;
2929 model : snippets . ModelDataMinimal ;
3030 languages : LANGUAGE [ ] ;
31+ providers : InferenceProvider [ ] ;
3132 opts ?: Record < string , unknown > ;
3233} [ ] = [
3334 {
@@ -39,6 +40,7 @@ const TEST_CASES: {
3940 inference : "" ,
4041 } ,
4142 languages : [ "sh" , "js" , "py" ] ,
43+ providers : [ "hf-inference" , "together" ] ,
4244 opts : { streaming : false } ,
4345 } ,
4446 {
@@ -50,6 +52,7 @@ const TEST_CASES: {
5052 inference : "" ,
5153 } ,
5254 languages : [ "sh" , "js" , "py" ] ,
55+ providers : [ "hf-inference" ] ,
5356 opts : { streaming : true } ,
5457 } ,
5558 {
@@ -61,6 +64,7 @@ const TEST_CASES: {
6164 inference : "" ,
6265 } ,
6366 languages : [ "sh" , "js" , "py" ] ,
67+ providers : [ "hf-inference" ] ,
6468 opts : { streaming : false } ,
6569 } ,
6670 {
@@ -72,6 +76,7 @@ const TEST_CASES: {
7276 inference : "" ,
7377 } ,
7478 languages : [ "sh" , "js" , "py" ] ,
79+ providers : [ "hf-inference" ] ,
7580 opts : { streaming : true } ,
7681 } ,
7782 {
@@ -82,6 +87,7 @@ const TEST_CASES: {
8287 tags : [ ] ,
8388 inference : "" ,
8489 } ,
90+ providers : [ "hf-inference" ] ,
8591 languages : [ "sh" , "js" , "py" ] ,
8692 } ,
8793] as const ;
@@ -113,31 +119,41 @@ function getFixtureFolder(testName: string): string {
113119function generateInferenceSnippet (
114120 model : snippets . ModelDataMinimal ,
115121 language : LANGUAGE ,
122+ provider : InferenceProvider ,
116123 opts ?: Record < string , unknown >
117124) : InferenceSnippet [ ] {
118- const generatedSnippets = GET_SNIPPET_FN [ language ] ( model , "api_token" , opts ) ;
125+ const generatedSnippets = GET_SNIPPET_FN [ language ] ( model , "api_token" , provider , opts ) ;
119126 return Array . isArray ( generatedSnippets ) ? generatedSnippets : [ generatedSnippets ] ;
120127}
121128
122- async function getExpectedInferenceSnippet ( testName : string , language : LANGUAGE ) : Promise < InferenceSnippet [ ] > {
129+ async function getExpectedInferenceSnippet (
130+ testName : string ,
131+ language : LANGUAGE ,
132+ provider : InferenceProvider
133+ ) : Promise < InferenceSnippet [ ] > {
123134 const fixtureFolder = getFixtureFolder ( testName ) ;
124135 const files = await fs . readdir ( fixtureFolder ) ;
125136
126137 const expectedSnippets : InferenceSnippet [ ] = [ ] ;
127- for ( const file of files . filter ( ( file ) => file . endsWith ( "." + language ) ) . sort ( ) ) {
128- const client = path . basename ( file ) . split ( "." ) . slice ( 1 , - 1 ) . join ( "." ) ; // e.g. '0.huggingface.js.js' => "huggingface.js"
138+ for ( const file of files . filter ( ( file ) => file . endsWith ( "." + language ) && file . includes ( `. ${ provider } .` ) ) . sort ( ) ) {
139+ const client = path . basename ( file ) . split ( "." ) . slice ( 1 , - 2 ) . join ( "." ) ; // e.g. '0.huggingface.js.replicate .js' => "huggingface.js"
129140 const content = await fs . readFile ( path . join ( fixtureFolder , file ) , { encoding : "utf-8" } ) ;
130- expectedSnippets . push ( client === "default" ? { content } : { client, content } ) ;
141+ expectedSnippets . push ( { client, content } ) ;
131142 }
132143 return expectedSnippets ;
133144}
134145
135- async function saveExpectedInferenceSnippet ( testName : string , language : LANGUAGE , snippets : InferenceSnippet [ ] ) {
146+ async function saveExpectedInferenceSnippet (
147+ testName : string ,
148+ language : LANGUAGE ,
149+ provider : InferenceProvider ,
150+ snippets : InferenceSnippet [ ]
151+ ) {
136152 const fixtureFolder = getFixtureFolder ( testName ) ;
137153 await fs . mkdir ( fixtureFolder , { recursive : true } ) ;
138154
139155 for ( const [ index , snippet ] of snippets . entries ( ) ) {
140- const file = path . join ( fixtureFolder , `${ index } .${ snippet . client ?? "default" } .${ language } ` ) ;
156+ const file = path . join ( fixtureFolder , `${ index } .${ snippet . client ?? "default" } .${ provider } . ${ language } ` ) ;
141157 await fs . writeFile ( file , snippet . content ) ;
142158 }
143159}
@@ -147,13 +163,15 @@ if (import.meta.vitest) {
147163 const { describe, expect, it } = import . meta. vitest ;
148164
149165 describe ( "inference API snippets" , ( ) => {
150- TEST_CASES . forEach ( ( { testName, model, languages, opts } ) => {
166+ TEST_CASES . forEach ( ( { testName, model, languages, providers , opts } ) => {
151167 describe ( testName , ( ) => {
152168 languages . forEach ( ( language ) => {
153- it ( language , async ( ) => {
154- const generatedSnippets = generateInferenceSnippet ( model , language , opts ) ;
155- const expectedSnippets = await getExpectedInferenceSnippet ( testName , language ) ;
156- expect ( generatedSnippets ) . toEqual ( expectedSnippets ) ;
169+ providers . forEach ( ( provider ) => {
170+ it ( language , async ( ) => {
171+ const generatedSnippets = generateInferenceSnippet ( model , language , provider , opts ) ;
172+ const expectedSnippets = await getExpectedInferenceSnippet ( testName , language , provider ) ;
173+ expect ( generatedSnippets ) . toEqual ( expectedSnippets ) ;
174+ } ) ;
157175 } ) ;
158176 } ) ;
159177 } ) ;
@@ -166,11 +184,13 @@ if (import.meta.vitest) {
166184 await fs . rm ( path . join ( rootDirFinder ( ) , "snippets-fixtures" ) , { recursive : true , force : true } ) ;
167185
168186 console . debug ( " 🏭 Generating new fixtures..." ) ;
169- TEST_CASES . forEach ( ( { testName, model, languages, opts } ) => {
170- console . debug ( ` ${ testName } (${ languages . join ( ", " ) } )` ) ;
187+ TEST_CASES . forEach ( ( { testName, model, languages, providers , opts } ) => {
188+ console . debug ( ` ${ testName } (${ languages . join ( ", " ) } ) ( ${ providers . join ( ", " ) } ) ` ) ;
171189 languages . forEach ( async ( language ) => {
172- const generatedSnippets = generateInferenceSnippet ( model , language , opts ) ;
173- await saveExpectedInferenceSnippet ( testName , language , generatedSnippets ) ;
190+ providers . forEach ( async ( provider ) => {
191+ const generatedSnippets = generateInferenceSnippet ( model , language , provider , opts ) ;
192+ await saveExpectedInferenceSnippet ( testName , language , provider , generatedSnippets ) ;
193+ } ) ;
174194 } ) ;
175195 } ) ;
176196 console . log ( "✅ All done!" ) ;
0 commit comments