Skip to content

Commit 9034b7d

Browse files
committed
New way of injecting mapping into tests + record VCR
1 parent 7d19138 commit 9034b7d

File tree

4 files changed

+86
-6
lines changed

4 files changed

+86
-6
lines changed

packages/inference/src/lib/getProviderModelId.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ export async function getProviderModelId(
3030
options.taskHint === "text-generation" && options.chatCompletion ? "conversational" : options.taskHint;
3131

3232
// A dict called HARDCODED_MODEL_ID_MAPPING takes precedence in all cases (useful for dev purposes)
33-
if (HARDCODED_MODEL_ID_MAPPING[params.model]) {
34-
return HARDCODED_MODEL_ID_MAPPING[params.model];
33+
if (HARDCODED_MODEL_ID_MAPPING[params.provider]?.[params.model]) {
34+
return HARDCODED_MODEL_ID_MAPPING[params.provider][params.model];
3535
}
3636

3737
let inferenceProviderMapping: InferenceProviderMapping | null;
Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,25 @@
1-
import type { ModelId } from "../types";
1+
import type { InferenceProvider } from "../types";
2+
import { type ModelId } from "../types";
23

34
type ProviderId = string;
4-
55
/**
66
* If you want to try to run inference for a new model locally before it's registered on huggingface.co
77
* for a given Inference Provider,
88
* you can add it to the following dictionary, for dev purposes.
9+
*
10+
* We also inject into this dictionary from tests.
911
*/
10-
export const HARDCODED_MODEL_ID_MAPPING: Record<ModelId, ProviderId> = {
12+
export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelId, ProviderId>> = {
1113
/**
1214
* "HF model ID" => "Model ID on Inference Provider's side"
15+
*
16+
* Example:
17+
* "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
1318
*/
14-
// "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
19+
"fal-ai": {},
20+
"fireworks-ai": {},
21+
"hf-inference": {},
22+
replicate: {},
23+
sambanova: {},
24+
together: {},
1525
};

packages/inference/test/HfInference.spec.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import { chatCompletion, HfInference } from "../src";
66
import { textToVideo } from "../src/tasks/cv/textToVideo";
77
import { readTestFile } from "./test-files";
88
import "./vcr";
9+
import { HARDCODED_MODEL_ID_MAPPING } from "../src/providers/consts";
910

1011
const TIMEOUT = 60000 * 3;
1112
const env = import.meta.env;
@@ -1083,6 +1084,10 @@ describe.concurrent("HfInference", () => {
10831084
() => {
10841085
const client = new HfInference(env.HF_FIREWORKS_KEY);
10851086

1087+
HARDCODED_MODEL_ID_MAPPING["fireworks-ai"] = {
1088+
"deepseek-ai/DeepSeek-R1": "accounts/fireworks/models/deepseek-r1",
1089+
};
1090+
10861091
it("chatCompletion", async () => {
10871092
const res = await client.chatCompletion({
10881093
model: "deepseek-ai/DeepSeek-R1",

packages/inference/test/tapes.json

Lines changed: 65 additions & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)