
Commit 5629b86

Together.ai implem
1 parent a413824 commit 5629b86

File tree

5 files changed: +97 -3 lines changed


packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 6 additions & 0 deletions
@@ -1,4 +1,5 @@
 import { SAMBANOVA_API_BASE_URL, SAMBANOVA_MODEL_IDS } from "../providers/sambanova";
+import { TOGETHER_API_BASE_URL, TOGETHER_MODEL_IDS } from "../providers/together";
 import { INFERENCE_PROVIDERS, type InferenceTask, type Options, type RequestArgs } from "../types";
 import { omit } from "../utils/omit";
 import { HF_HUB_URL } from "./getDefaultTask";
@@ -66,6 +67,9 @@ export async function makeRequestOptions(
 		case "sambanova":
 			model = SAMBANOVA_MODEL_IDS[model];
 			break;
+		case "together":
+			model = TOGETHER_MODEL_IDS[model]?.id ?? model;
+			break;
 		default:
 			break;
 	}
@@ -113,6 +117,8 @@ export async function makeRequestOptions(
 	switch (provider) {
 		case "sambanova":
 			return SAMBANOVA_API_BASE_URL;
+		case "together":
+			return TOGETHER_API_BASE_URL;
 		default:
 			break;
 	}
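For reference, the lookup added above degrades gracefully: a Hugging Face model ID listed in the table is rewritten to its Together.ai counterpart, and an unknown ID passes through unchanged thanks to the `?? model` fallback. A minimal sketch of that behavior (the `resolveTogetherModel` helper is illustrative, not part of this commit):

import { TOGETHER_MODEL_IDS } from "../providers/together";

// Illustrative helper mirroring the new "together" case: mapped IDs resolve
// to their Together.ai counterpart; anything else falls through unchanged.
function resolveTogetherModel(model: string): string {
	return TOGETHER_MODEL_IDS[model]?.id ?? model;
}

resolveTogetherModel("meta-llama/Llama-3.3-70B-Instruct"); // "meta-llama/Llama-3.3-70B-Instruct-Turbo"
resolveTogetherModel("some-org/unmapped-model"); // "some-org/unmapped-model"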

packages/inference/src/providers/sambanova.ts

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ export const SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";
  * or keep it up-to-date.
  *
  * As a fallback, if the above is not possible, ask Sambanova to
- * provide the mapping as an API.
+ * provide the mapping as a fetchable API.
  */
 type SambanovaId = string;

packages/inference/src/providers/together.ts

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+import type { ModelId } from "../types";
+
+export const TOGETHER_API_BASE_URL = "https://api.together.xyz";
+
+/**
+ * Same comment as in sambanova.ts
+ */
+type TogetherId = string;
+
+/**
+ * https://docs.together.ai/reference/models-1
+ */
+export const TOGETHER_MODEL_IDS: Record<
+	ModelId,
+	{ id: TogetherId; type: "chat" | "embedding" | "image" | "language" | "moderation" }
+> = {
+	"BAAI/bge-base-en-v1.5": { id: "BAAI/bge-base-en-v1.5", type: "embedding" },
+	"black-forest-labs/FLUX.1-Canny-dev": { id: "black-forest-labs/FLUX.1-canny", type: "image" },
+	"black-forest-labs/FLUX.1-Depth-dev": { id: "black-forest-labs/FLUX.1-depth", type: "image" },
+	"black-forest-labs/FLUX.1-dev": { id: "black-forest-labs/FLUX.1-dev", type: "image" },
+	"black-forest-labs/FLUX.1-Redux-dev": { id: "black-forest-labs/FLUX.1-redux", type: "image" },
+	"black-forest-labs/FLUX.1-schnell": { id: "black-forest-labs/FLUX.1-pro", type: "image" },
+	"databricks/dbrx-instruct": { id: "databricks/dbrx-instruct", type: "chat" },
+	"deepseek-ai/deepseek-llm-67b-chat": { id: "deepseek-ai/deepseek-llm-67b-chat", type: "chat" },
+	"google/gemma-2-9b-it": { id: "google/gemma-2-9b-it", type: "chat" },
+	"google/gemma-2b-it": { id: "google/gemma-2-27b-it", type: "chat" },
+	"llava-hf/llava-v1.6-mistral-7b-hf": { id: "llava-hf/llava-v1.6-mistral-7b-hf", type: "chat" },
+	"meta-llama/Llama-2-13b-chat-hf": { id: "meta-llama/Llama-2-13b-chat-hf", type: "chat" },
+	"meta-llama/Llama-2-70b-hf": { id: "meta-llama/Llama-2-70b-hf", type: "language" },
+	"meta-llama/Llama-2-7b-chat-hf": { id: "meta-llama/Llama-2-7b-chat-hf", type: "chat" },
+	"meta-llama/Llama-3.2-11B-Vision-Instruct": { id: "meta-llama/Llama-Vision-Free", type: "chat" },
+	"meta-llama/Llama-3.2-3B-Instruct": { id: "meta-llama/Llama-3.2-3B-Instruct-Turbo", type: "chat" },
+	"meta-llama/Llama-3.2-90B-Vision-Instruct": { id: "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo", type: "chat" },
+	"meta-llama/Llama-3.3-70B-Instruct": { id: "meta-llama/Llama-3.3-70B-Instruct-Turbo", type: "chat" },
+	"meta-llama/Llama-Guard-3-11B-Vision": { id: "meta-llama/Llama-Guard-3-11B-Vision-Turbo", type: "moderation" },
+	"meta-llama/LlamaGuard-7b": { id: "Meta-Llama/Llama-Guard-7b", type: "moderation" },
+	"meta-llama/Meta-Llama-3-70B-Instruct": { id: "meta-llama/Llama-3-70b-chat-hf", type: "chat" },
+	"meta-llama/Meta-Llama-3-8B": { id: "meta-llama/Meta-Llama-3-8B", type: "language" },
+	"meta-llama/Meta-Llama-3-8B-Instruct": { id: "togethercomputer/Llama-3-8b-chat-hf-int4", type: "chat" },
+	"meta-llama/Meta-Llama-3.1-405B-Instruct": { id: "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo", type: "chat" },
+	"meta-llama/Meta-Llama-3.1-70B-Instruct": { id: "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", type: "chat" },
+	"meta-llama/Meta-Llama-3.1-8B-Instruct": { id: "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo-128K", type: "chat" },
+	"microsoft/WizardLM-2-8x22B": { id: "microsoft/WizardLM-2-8x22B", type: "chat" },
+	"mistralai/Mistral-7B-Instruct-v0.3": { id: "mistralai/Mistral-7B-Instruct-v0.3", type: "chat" },
+	"mistralai/Mixtral-8x22B-Instruct-v0.1": { id: "mistralai/Mixtral-8x22B-Instruct-v0.1", type: "chat" },
+	"mistralai/Mixtral-8x7B-Instruct-v0.1": { id: "mistralai/Mixtral-8x7B-Instruct-v0.1", type: "chat" },
+	"mistralai/Mixtral-8x7B-v0.1": { id: "mistralai/Mixtral-8x7B-v0.1", type: "language" },
+	"NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": { id: "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO", type: "chat" },
+	"nvidia/Llama-3.1-Nemotron-70B-Instruct-HF": { id: "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF", type: "chat" },
+	"Qwen/Qwen2-72B-Instruct": { id: "Qwen/Qwen2-72B-Instruct", type: "chat" },
+	"Qwen/Qwen2.5-72B-Instruct": { id: "Qwen/Qwen2.5-72B-Instruct-Turbo", type: "chat" },
+	"Qwen/Qwen2.5-7B-Instruct": { id: "Qwen/Qwen2.5-7B-Instruct-Turbo", type: "chat" },
+	"Qwen/Qwen2.5-Coder-32B-Instruct": { id: "Qwen/Qwen2.5-Coder-32B-Instruct", type: "chat" },
+	"Qwen/QwQ-32B-Preview": { id: "Qwen/QwQ-32B-Preview", type: "chat" },
+	"scb10x/llama-3-typhoon-v1.5-8b-instruct": { id: "scb10x/scb10x-llama3-typhoon-v1-5-8b-instruct", type: "chat" },
+	"scb10x/llama-3-typhoon-v1.5x-70b-instruct-awq": { id: "scb10x/scb10x-llama3-typhoon-v1-5x-4f316", type: "chat" },
+	"stabilityai/stable-diffusion-xl-base-1.0": { id: "stabilityai/stable-diffusion-xl-base-1.0", type: "image" },
+	"togethercomputer/m2-bert-80M-32k-retrieval": { id: "togethercomputer/m2-bert-80M-32k-retrieval", type: "embedding" },
+	"togethercomputer/m2-bert-80M-8k-retrieval": { id: "togethercomputer/m2-bert-80M-8k-retrieval", type: "embedding" },
+};
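Because the table keys are Hugging Face model IDs and each value carries both the Together.ai ID and a task type, it can be queried like any plain record. A small sketch of consuming it (assumed usage, not part of the commit):

import { TOGETHER_MODEL_IDS } from "./together";

// Collect every Hugging Face ID that maps to a Together.ai image model.
const imageModels = Object.entries(TOGETHER_MODEL_IDS)
	.filter(([, mapping]) => mapping.type === "image")
	.map(([hfId, mapping]) => `${hfId} -> ${mapping.id}`);

console.log(imageModels); // FLUX.1 variants and stable-diffusion-xl-base-1.0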

packages/inference/src/tasks/nlp/chatCompletion.ts

Lines changed: 2 additions & 1 deletion
@@ -22,7 +22,8 @@ export async function chatCompletion(
 		typeof res?.created === "number" &&
 		typeof res?.id === "string" &&
 		typeof res?.model === "string" &&
-		typeof res?.system_fingerprint === "string" &&
+		/// Together.ai does not output a system_fingerprint
+		(res.system_fingerprint === undefined || typeof res.system_fingerprint === "string") &&
 		typeof res?.usage === "object";
 
 	if (!isValidOutput) {
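The relaxed predicate accepts a response that omits `system_fingerprint` while still rejecting one where the field is present with the wrong type. A standalone sketch of just that clause (hypothetical helper name, not part of the commit):

// Hypothetical standalone version of the relaxed clause above.
function hasValidSystemFingerprint(res: { system_fingerprint?: unknown }): boolean {
	return res.system_fingerprint === undefined || typeof res.system_fingerprint === "string";
}

hasValidSystemFingerprint({}); // true — Together.ai responses omit the field
hasValidSystemFingerprint({ system_fingerprint: "fp_44709d6fcb" }); // true
hasValidSystemFingerprint({ system_fingerprint: 42 }); // false — wrong type still fails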

packages/inference/test/HfInference.spec.ts

Lines changed: 28 additions & 1 deletion
@@ -786,7 +786,34 @@ describe.concurrent(
 				out += chunk.choices[0].delta.content;
 			}
 		}
-		console.warn(out);
+		expect(out).toContain("2");
+	});
+
+	it("chatCompletion together", async () => {
+		const hf = new HfInference(env.TOGETHER_KEY);
+		const res = await hf.chatCompletion({
+			model: "meta-llama/Llama-3.3-70B-Instruct",
+			provider: "together",
+			messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
+		});
+		if (res.choices && res.choices.length > 0) {
+			const completion = res.choices[0].message?.content;
+			expect(completion).toContain("two");
+		}
+	});
+	it("chatCompletion together stream", async () => {
+		const hf = new HfInference(env.TOGETHER_KEY);
+		const stream = hf.chatCompletionStream({
+			model: "meta-llama/Llama-3.3-70B-Instruct",
+			provider: "together",
+			messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }],
+		}) as AsyncGenerator<ChatCompletionStreamOutput>;
+		let out = "";
+		for await (const chunk of stream) {
+			if (chunk.choices && chunk.choices.length > 0) {
+				out += chunk.choices[0].delta.content;
+			}
+		}
 		expect(out).toContain("2");
 	});
 },
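Outside the test suite, calling the new provider looks like any other HfInference call; a minimal sketch mirroring the tests above (assumes a valid Together.ai key in the TOGETHER_KEY environment variable):

import { HfInference } from "@huggingface/inference";

// Sketch mirroring the new test: route a chat completion through Together.ai.
const hf = new HfInference(process.env.TOGETHER_KEY);
const res = await hf.chatCompletion({
	model: "meta-llama/Llama-3.3-70B-Instruct",
	provider: "together",
	messages: [{ role: "user", content: "One plus one equals" }],
});
console.log(res.choices[0]?.message?.content);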
