Skip to content

Commit 737b844

Browse files
fabistbyaron2
andauthored
Enhance OpenAI to be Azure OpenAi compatible (#3918)
Signed-off-by: fabistb <[email protected]> Co-authored-by: Yaron Schneider <[email protected]>
1 parent 78efc97 commit 737b844

File tree

12 files changed

+203
-7
lines changed

12 files changed

+203
-7
lines changed

conversation/openai/metadata.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/*
2+
Copyright 2025 The Dapr Authors
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
*/
15+
16+
package openai
17+
18+
import "github.com/dapr/components-contrib/conversation"
19+
20+
// OpenAILangchainMetadata extends LangchainMetadata with OpenAI-specific properties.
21+
type OpenAILangchainMetadata struct {
22+
conversation.LangchainMetadata `json:",inline" mapstructure:",squash"`
23+
APIType string `json:"apiType" mapstructure:"apiType"`
24+
APIVersion string `json:"apiVersion" mapstructure:"apiVersion"`
25+
}

conversation/openai/metadata.yaml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,17 @@ metadata:
4040
A time-to-live value for a prompt cache to expire. Uses Golang durations
4141
type: string
4242
example: '10m'
43+
- name: apiVersion
44+
required: false
45+
description: |
46+
The API version to use for the Azure OpenAI service. This is required when using Azure OpenAI.
47+
type: string
48+
example: '2025-01-01-preview'
49+
default: ''
50+
- name: apiType
51+
required: false
52+
description: |
53+
The type of API to use for the OpenAI service. This is required when using Azure OpenAI.
54+
type: string
55+
example: 'azure'
56+
default: ''
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
Copyright 2025 The Dapr Authors
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
Unless required by applicable law or agreed to in writing, software
8+
distributed under the License is distributed on an "AS IS" BASIS,
9+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
See the License for the specific language governing permissions and
11+
limitations under the License.
12+
*/
13+
14+
package openai
15+
16+
import (
17+
"encoding/json"
18+
"testing"
19+
20+
"github.com/stretchr/testify/assert"
21+
"github.com/stretchr/testify/require"
22+
23+
"github.com/dapr/components-contrib/conversation"
24+
)
25+
26+
func TestOpenaiLangchainMetadata(t *testing.T) {
27+
t.Run("json marshaling with endpoint", func(t *testing.T) {
28+
metadata := OpenAILangchainMetadata{
29+
LangchainMetadata: conversation.LangchainMetadata{
30+
Key: "test-key",
31+
Model: "gpt-4",
32+
CacheTTL: "10m",
33+
Endpoint: "https://custom-endpoint.openai.azure.com/",
34+
},
35+
APIType: "azure",
36+
APIVersion: "2025-01-01-preview",
37+
}
38+
39+
bytes, err := json.Marshal(metadata)
40+
require.NoError(t, err)
41+
42+
var unmarshaled OpenAILangchainMetadata
43+
err = json.Unmarshal(bytes, &unmarshaled)
44+
require.NoError(t, err)
45+
46+
assert.Equal(t, metadata.Key, unmarshaled.Key)
47+
assert.Equal(t, metadata.Model, unmarshaled.Model)
48+
assert.Equal(t, metadata.CacheTTL, unmarshaled.CacheTTL)
49+
assert.Equal(t, metadata.Endpoint, unmarshaled.Endpoint)
50+
assert.Equal(t, metadata.APIType, unmarshaled.APIType)
51+
assert.Equal(t, metadata.APIVersion, unmarshaled.APIVersion)
52+
})
53+
54+
t.Run("json unmarshaling with endpoint", func(t *testing.T) {
55+
jsonStr := `{"key": "test-key", "model": "gpt-4", "endpoint": "https://custom-endpoint.openai.azure.com/", "apiType": "azure", "apiVersion": "2025-01-01-preview"}`
56+
57+
var metadata OpenAILangchainMetadata
58+
err := json.Unmarshal([]byte(jsonStr), &metadata)
59+
require.NoError(t, err)
60+
61+
assert.Equal(t, "test-key", metadata.Key)
62+
assert.Equal(t, "gpt-4", metadata.Model)
63+
assert.Equal(t, "https://custom-endpoint.openai.azure.com/", metadata.Endpoint)
64+
assert.Equal(t, "azure", metadata.APIType)
65+
assert.Equal(t, "2025-01-01-preview", metadata.APIVersion)
66+
})
67+
}

conversation/openai/openai.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package openai
1616

1717
import (
1818
"context"
19+
"errors"
1920
"reflect"
2021

2122
"github.com/dapr/components-contrib/conversation"
@@ -44,7 +45,7 @@ func NewOpenAI(logger logger.Logger) conversation.Conversation {
4445
const defaultModel = "gpt-4o"
4546

4647
func (o *OpenAI) Init(ctx context.Context, meta conversation.Metadata) error {
47-
md := conversation.LangchainMetadata{}
48+
md := OpenAILangchainMetadata{}
4849
err := kmeta.DecodeMetadata(meta.Properties, &md)
4950
if err != nil {
5051
return err
@@ -65,6 +66,14 @@ func (o *OpenAI) Init(ctx context.Context, meta conversation.Metadata) error {
6566
options = append(options, openai.WithBaseURL(md.Endpoint))
6667
}
6768

69+
if md.APIType == "azure" {
70+
if md.Endpoint == "" || md.APIVersion == "" {
71+
return errors.New("endpoint and apiVersion must be provided when apiType is set to 'azure'")
72+
}
73+
74+
options = append(options, openai.WithAPIType(openai.APITypeAzure), openai.WithAPIVersion(md.APIVersion))
75+
}
76+
6877
llm, err := openai.New(options...)
6978
if err != nil {
7079
return err
@@ -84,7 +93,7 @@ func (o *OpenAI) Init(ctx context.Context, meta conversation.Metadata) error {
8493
}
8594

8695
func (o *OpenAI) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
87-
metadataStruct := conversation.LangchainMetadata{}
96+
metadataStruct := OpenAILangchainMetadata{}
8897
metadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo, metadata.ConversationType)
8998
return
9099
}

conversation/openai/openai_test.go

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,47 @@ func TestInit(t *testing.T) {
5555
// we're mainly testing that initialization succeeds
5656
},
5757
},
58+
{
59+
name: "with apiType azure and missing apiVersion",
60+
metadata: map[string]string{
61+
"key": "test-key",
62+
"model": "gpt-4",
63+
"apiType": "azure",
64+
"endpoint": "https://custom-endpoint.openai.azure.com/",
65+
},
66+
testFn: func(t *testing.T, o *OpenAI, err error) {
67+
require.Error(t, err)
68+
assert.EqualError(t, err, "endpoint and apiVersion must be provided when apiType is set to 'azure'")
69+
},
70+
},
71+
{
72+
name: "with apiType azure and custom apiVersion",
73+
metadata: map[string]string{
74+
"key": "test-key",
75+
"model": "gpt-4",
76+
"apiType": "azure",
77+
"endpoint": "https://custom-endpoint.openai.azure.com/",
78+
"apiVersion": "2025-01-01-preview",
79+
},
80+
testFn: func(t *testing.T, o *OpenAI, err error) {
81+
require.NoError(t, err)
82+
assert.NotNil(t, o.LLM)
83+
},
84+
},
85+
{
86+
name: "with apiType azure but missing endpoint",
87+
metadata: map[string]string{
88+
"key": "test-key",
89+
"model": "gpt-4",
90+
"apiType": "azure",
91+
"apiVersion": "2025-01-01-preview",
92+
},
93+
testFn: func(t *testing.T, o *OpenAI, err error) {
94+
require.Error(t, err)
95+
assert.EqualError(t, err, "endpoint and apiVersion must be provided when apiType is set to 'azure'")
96+
},
97+
},
5898
}
59-
6099
for _, tc := range testCases {
61100
t.Run(tc.name, func(t *testing.T) {
62101
o := NewOpenAI(logger.NewLogger("openai test"))

tests/config/conversation/README.md

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ This directory contains conformance tests for all conversation components, inclu
55
## Available Components
66

77
- **echo** - Simple echo component for testing (no configuration needed)
8-
- **openai** - OpenAI GPT models
8+
- **openai** - OpenAI GPT models (also supports Azure OpenAI)
99
- **anthropic** - Anthropic Claude models
1010
- **googleai** - Google Gemini models
1111
- **mistral** - Mistral AI models
@@ -52,6 +52,14 @@ export OPENAI_API_KEY="your_openai_api_key"
5252
```
5353
Get your API key from: https://platform.openai.com/api-keys
5454

55+
### Azure OpenAI
56+
```bash
57+
export AZURE_OPENAI_API_KEY="your_openai_api_key"
58+
export AZURE_OPENAI_ENDPOINT="your_azureopenai_endpoint_here"
59+
export AZURE_OPENAI_API_VERSION="your_azreopenai_api_version_here"
60+
```
61+
Get your configuration values from: https://ai.azure.com/
62+
5563
### Anthropic
5664
```bash
5765
export ANTHROPIC_API_KEY="your_anthropic_api_key"
@@ -142,4 +150,5 @@ This approach provides better reliability and compatibility while maintaining ac
142150
- Cost-effective models are used by default to minimize API costs
143151
- HuggingFace uses the OpenAI compatibility layer as a workaround due to langchaingo API issues
144152
- Ollama requires a local server and must be explicitly enabled
153+
- OpenAI component is tested for OpenAI and Azure
145154
- All tests include proper initialization and basic conversation functionality testing

tests/config/conversation/env.template

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44
# OpenAI API Key - Get from https://platform.openai.com/api-keys
55
OPENAI_API_KEY=your_openai_api_key_here
66

7+
# Azure OpenAI - Get from https://ai.azure.com/
8+
AZURE_OPENAI_API_KEY=your_azureopenai_api_key_here
9+
AZURE_OPENAI_ENDPOINT=your_azureopenai_endpoint_here
10+
AZURE_OPENAI_API_VERSION=your_azreopenai_api_version_here
11+
712
# Anthropic API Key - Get from https://console.anthropic.com/
813
ANTHROPIC_API_KEY=your_anthropic_api_key_here
914

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
apiVersion: dapr.io/v1alpha1
2+
kind: Component
3+
metadata:
4+
name: openai
5+
spec:
6+
type: conversation.openai
7+
version: v1
8+
metadata:
9+
- name: key
10+
value: "${{AZURE_OPENAI_API_KEY}}"
11+
- name: model
12+
value: "gpt-4o-mini"
13+
- name: endpoint
14+
value: "${{AZURE_OPENAI_ENDPOINT}}"
15+
- name: apiType
16+
value: "azure"
17+
- name: apiVersion
18+
value: "${{AZURE_OPENAI_API_VERSION}}"

tests/config/conversation/test_conformance.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ echo " ./test_conformance.sh"
2121
echo ""
2222
echo "Option 2: Set environment variables directly:"
2323
echo " export OPENAI_API_KEY=\"your_openai_api_key\""
24+
echo " export AZURE_OPENAI_API_KEY=\"your_azureopenai_api_key\""
25+
echo " export AZURE_OPENAI_ENDPOINT=\"your_azureopenai_endpoint\""
26+
echo " export AZURE_OPENAI_API_VERSION=\"your_azureopenai_api_version\""
2427
echo " export ANTHROPIC_API_KEY=\"your_anthropic_api_key\""
2528
echo " export GOOGLE_AI_API_KEY=\"your_google_ai_api_key\""
2629
echo " export MISTRAL_API_KEY=\"your_mistral_api_key\""

0 commit comments

Comments
 (0)