Skip to content

Commit 1a70750

Browse files
authored
Merge pull request #1 from felixtrz/provider/llama
Add Provider support for the new Llama API platform
2 parents 792c171 + 336852f commit 1a70750

File tree

2 files changed

+204
-0
lines changed

2 files changed

+204
-0
lines changed

packages/prai/src/provider/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ export * from './mock.js'
22
export * from './openai.js'
33
export * from './groq.js'
44
export * from './gemini.js'
5+
export * from './llama.js'
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
import { Price, Provider } from '../model.js'
2+
import { Schema, ZodObject, ZodString, ZodUnion } from 'zod'
3+
4+
import { Message } from '../step.js'
5+
import { buildJsonSchema } from '../schema/json.js'
6+
7+
interface LlamaOptions {
8+
apiKey: string
9+
}
10+
11+
interface LlamaMessage {
12+
role: string
13+
content: string
14+
}
15+
16+
interface LlamaResponse {
17+
completion_message: {
18+
content: {
19+
text: string
20+
}
21+
}
22+
}
23+
24+
// Simplified JSON schema types for Llama API (which doesn't support all prai JsonSchema features)
25+
interface LlamaJsonSchema {
26+
type: 'object' | 'array' | 'string' | 'number' | 'boolean' | 'integer' | 'null'
27+
description?: string
28+
additionalProperties?: boolean
29+
properties?: Record<string, LlamaJsonSchema>
30+
required?: string[]
31+
items?: LlamaJsonSchema
32+
enum?: string[]
33+
}
34+
35+
interface LlamaRequestBody {
36+
model: string
37+
messages: LlamaMessage[]
38+
response_format?: {
39+
type: 'json_schema'
40+
json_schema: {
41+
name: string
42+
schema: LlamaJsonSchema
43+
}
44+
}
45+
}
46+
47+
async function llamaQuery(
48+
model: string,
49+
messages: Array<LlamaMessage>,
50+
apiKey: string,
51+
abortSignal: AbortSignal | undefined,
52+
responseFormat?: LlamaRequestBody['response_format'],
53+
): Promise<string> {
54+
const requestBody: LlamaRequestBody = {
55+
model,
56+
messages,
57+
...(responseFormat && { response_format: responseFormat }),
58+
}
59+
60+
const response = await fetch('https://api.llama.com/v1/chat/completions', {
61+
method: 'POST',
62+
headers: {
63+
'Content-Type': 'application/json',
64+
Authorization: `Bearer ${apiKey}`,
65+
},
66+
body: JSON.stringify(requestBody),
67+
signal: abortSignal,
68+
})
69+
70+
if (!response.ok) {
71+
const errorText = await response.text()
72+
throw new Error(`HTTP ${response.status}: ${response.statusText} - ${errorText}`)
73+
}
74+
75+
const data: LlamaResponse = await response.json()
76+
return data.completion_message.content.text
77+
}
78+
79+
function buildLlamaResponseFormat(
80+
schema: Schema,
81+
wrapInObject: boolean,
82+
): LlamaRequestBody['response_format'] | undefined {
83+
if (schema instanceof ZodString) {
84+
return undefined
85+
}
86+
87+
let praiSchema = buildJsonSchema(schema)
88+
89+
// Convert prai JsonSchema to LlamaJsonSchema by removing unsupported properties
90+
const convertToLlamaSchema = (obj: any): LlamaJsonSchema => {
91+
if (typeof obj !== 'object' || obj === null) return obj
92+
93+
const result: LlamaJsonSchema = {
94+
type: obj.type,
95+
}
96+
97+
if (obj.description) result.description = obj.description
98+
if (obj.additionalProperties !== undefined) result.additionalProperties = obj.additionalProperties
99+
if (obj.enum) result.enum = obj.enum
100+
if (obj.required) result.required = obj.required
101+
102+
if (obj.properties) {
103+
result.properties = {}
104+
for (const [key, value] of Object.entries(obj.properties)) {
105+
result.properties[key] = convertToLlamaSchema(value)
106+
}
107+
}
108+
109+
if (obj.items) {
110+
result.items = convertToLlamaSchema(obj.items)
111+
}
112+
113+
return result
114+
}
115+
116+
let llamaSchema = convertToLlamaSchema(praiSchema)
117+
118+
if (wrapInObject) {
119+
llamaSchema = {
120+
type: 'object',
121+
additionalProperties: false,
122+
properties: {
123+
result: llamaSchema,
124+
},
125+
required: ['result'],
126+
}
127+
}
128+
129+
return {
130+
type: 'json_schema',
131+
json_schema: {
132+
name: 'response_schema',
133+
schema: llamaSchema,
134+
},
135+
}
136+
}
137+
138+
export function llama(options: LlamaOptions): Provider {
139+
return {
140+
async query(
141+
modelName: string,
142+
modelPrice: Price | undefined,
143+
modelOptions: {},
144+
messages: Array<Message>,
145+
schema: Schema,
146+
abortSignal: AbortSignal | undefined,
147+
): Promise<{ content: string; cost?: number }> {
148+
const transformedMessages = transformMessages(messages)
149+
if (!(schema instanceof ZodObject || schema instanceof ZodUnion || schema instanceof ZodString)) {
150+
const responseFormat = buildLlamaResponseFormat(schema, true)
151+
const result = await llamaQuery(modelName, transformedMessages, options.apiKey, abortSignal, responseFormat)
152+
const { result: parsedResult } = JSON.parse(result)
153+
return { content: JSON.stringify(parsedResult) }
154+
}
155+
const responseFormat = buildLlamaResponseFormat(schema, false)
156+
const result = await llamaQuery(modelName, transformedMessages, options.apiKey, abortSignal, responseFormat)
157+
return { content: result }
158+
},
159+
async *streamingQuery(
160+
modelName: string,
161+
modelPrice: Price | undefined,
162+
modelOptions: {},
163+
messages: Array<Message>,
164+
schema: Schema,
165+
abortSignal: AbortSignal | undefined,
166+
): AsyncIterable<{ content: string; cost?: number }> {
167+
const transformedMessages = transformMessages(messages)
168+
if (!(schema instanceof ZodObject || schema instanceof ZodUnion || schema instanceof ZodString)) {
169+
const responseFormat = buildLlamaResponseFormat(schema, true)
170+
const result = await llamaQuery(modelName, transformedMessages, options.apiKey, abortSignal, responseFormat)
171+
const { result: parsedResult } = JSON.parse(result)
172+
yield { content: JSON.stringify(parsedResult) }
173+
} else {
174+
const responseFormat = buildLlamaResponseFormat(schema, false)
175+
const result = await llamaQuery(modelName, transformedMessages, options.apiKey, abortSignal, responseFormat)
176+
yield { content: result }
177+
}
178+
},
179+
}
180+
}
181+
182+
function transformMessages(messages: Array<Message>): Array<LlamaMessage> {
183+
return messages.map((message) => {
184+
if (message.role === 'user') {
185+
const content = message.content
186+
.map((item) => {
187+
if (item.type === 'text') {
188+
return item.text
189+
}
190+
return ''
191+
})
192+
.join('\n\n')
193+
return {
194+
role: message.role,
195+
content,
196+
}
197+
}
198+
return {
199+
role: message.role,
200+
content: message.content.map(({ text }) => text).join('\n\n'),
201+
}
202+
})
203+
}

0 commit comments

Comments
 (0)