Skip to content

Commit bdb879e

Browse files
authored
OpenAI supports embeddings, and storage supports streams and overwrite (Azure#46890)
1 parent c41b9eb commit bdb879e

File tree

6 files changed

+133
-31
lines changed

6 files changed

+133
-31
lines changed

sdk/provisioning/Azure.Provisioning.CloudMachine/api/Azure.Provisioning.CloudMachine.netstandard2.0.cs

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
public partial class AiModel
2+
{
3+
public AiModel(string model, string modelVersion) { }
4+
public string Model { get { throw null; } }
5+
public string ModelVersion { get { throw null; } }
6+
}
17
namespace Azure.CloudMachine
28
{
39
public partial class CloudMachineClient : Azure.CloudMachine.CloudMachineWorkspace
@@ -51,10 +57,11 @@ public readonly partial struct StorageServices
5157
private readonly int _dummyPrimitive;
5258
public void DeleteBlob(string path) { }
5359
public System.BinaryData DownloadBlob(string path) { throw null; }
54-
public string UploadBytes(System.BinaryData bytes, string? name = null) { throw null; }
55-
public string UploadBytes(byte[] bytes, string? name = null) { throw null; }
56-
public string UploadBytes(System.ReadOnlyMemory<byte> bytes, string? name = null) { throw null; }
57-
public string UploadJson(object json, string? name = null) { throw null; }
60+
public string UploadBinaryData(System.BinaryData data, string? name = null, bool overwrite = false) { throw null; }
61+
public string UploadBytes(byte[] bytes, string? name = null, bool overwrite = false) { throw null; }
62+
public string UploadBytes(System.ReadOnlyMemory<byte> bytes, string? name = null, bool overwrite = false) { throw null; }
63+
public string UploadJson(object json, string? name = null, bool overwrite = false) { throw null; }
64+
public string UploadStream(System.IO.Stream fileStream, string? name = null, bool overwrite = false) { throw null; }
5865
public void WhenBlobUploaded(System.Action<Azure.CloudMachine.StorageFile> function) { }
5966
}
6067
}
@@ -132,12 +139,11 @@ namespace Azure.Provisioning.CloudMachine.OpenAI
132139
public static partial class AzureOpenAIExtensions
133140
{
134141
public static OpenAI.Chat.ChatClient GetOpenAIChatClient(this Azure.Core.ClientWorkspace workspace) { throw null; }
142+
public static OpenAI.Embeddings.EmbeddingClient GetOpenAIEmbeddingsClient(this Azure.Core.ClientWorkspace workspace) { throw null; }
135143
}
136144
public partial class OpenAIFeature : Azure.Provisioning.CloudMachine.CloudMachineFeature
137145
{
138-
public OpenAIFeature(string model, string modelVersion) { }
139-
public string Model { get { throw null; } }
140-
public string ModelVersion { get { throw null; } }
146+
public OpenAIFeature(AiModel chatDeployment, AiModel? embeddingsDeployment = null) { }
141147
public override void AddTo(Azure.Provisioning.CloudMachine.CloudMachineInfrastructure cloudMachine) { }
142148
}
143149
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
public class AiModel
5+
{
6+
public AiModel(string model, string modelVersion) { Model = model; ModelVersion = modelVersion; }
7+
public string Model { get; }
8+
public string ModelVersion { get; }
9+
}

sdk/provisioning/Azure.Provisioning.CloudMachine/src/AzureSdkExtensions/OpenAIFeature.cs

Lines changed: 60 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,24 @@
88
using Azure.Provisioning.Authorization;
99
using Azure.Provisioning.CognitiveServices;
1010
using OpenAI.Chat;
11+
using OpenAI.Embeddings;
1112

1213
namespace Azure.Provisioning.CloudMachine.OpenAI;
1314

1415
public class OpenAIFeature : CloudMachineFeature
1516
{
16-
public string Model { get; }
17-
public string ModelVersion { get; }
17+
private AiModel _chatDeployment;
18+
private AiModel? _embeddingsDeployment;
1819

19-
public OpenAIFeature(string model, string modelVersion) { Model = model; ModelVersion = modelVersion; }
20+
public OpenAIFeature(AiModel chatDeployment, AiModel? embeddingsDeployment = default)
21+
{
22+
if (chatDeployment == null)
23+
{
24+
throw new ArgumentNullException(nameof(chatDeployment));
25+
}
26+
_chatDeployment = chatDeployment;
27+
_embeddingsDeployment = embeddingsDeployment;
28+
}
2029

2130
public override void AddTo(CloudMachineInfrastructure cloudMachine)
2231
{
@@ -31,31 +40,52 @@ public override void AddTo(CloudMachineInfrastructure cloudMachine)
3140
CustomSubDomainName = cloudMachine.Id
3241
},
3342
};
43+
cloudMachine.AddResource(cognitiveServices);
3444

3545
cloudMachine.AddResource(cognitiveServices.CreateRoleAssignment(
3646
CognitiveServicesBuiltInRole.CognitiveServicesOpenAIContributor,
3747
RoleManagementPrincipalType.User,
3848
cloudMachine.PrincipalIdParameter)
3949
);
4050

41-
// TODO: if we every support more than one deployment, they need to be chained using DependsOn.
42-
// The reason is that deployments need to be deployed/created serially.
43-
CognitiveServicesAccountDeployment deployment = new("openai_deployment", "2023-05-01")
51+
CognitiveServicesAccountDeployment chat = new("openai_deployment", "2023-05-01")
4452
{
4553
Parent = cognitiveServices,
4654
Name = cloudMachine.Id,
4755
Properties = new CognitiveServicesAccountDeploymentProperties()
4856
{
49-
Model = new CognitiveServicesAccountDeploymentModel() {
50-
Name = this.Model,
57+
Model = new CognitiveServicesAccountDeploymentModel()
58+
{
59+
Name = _chatDeployment.Model,
5160
Format = "OpenAI",
52-
Version = this.ModelVersion,
61+
Version = _chatDeployment.ModelVersion
5362
}
5463
},
5564
};
65+
cloudMachine.AddResource(chat);
5666

57-
cloudMachine.AddResource(cognitiveServices);
58-
cloudMachine.AddResource(deployment);
67+
if (_embeddingsDeployment != null)
68+
{
69+
CognitiveServicesAccountDeployment embeddings = new("openai_deployment", "2023-05-01")
70+
{
71+
Parent = cognitiveServices,
72+
Name = $"{cloudMachine.Id}-embedding",
73+
Properties = new CognitiveServicesAccountDeploymentProperties()
74+
{
75+
Model = new CognitiveServicesAccountDeploymentModel()
76+
{
77+
Name = _embeddingsDeployment.Model,
78+
Format = "OpenAI",
79+
Version = _embeddingsDeployment.ModelVersion
80+
}
81+
},
82+
};
83+
84+
// Ensure that additional deployments, are chained using DependsOn.
85+
// The reason is that deployments need to be deployed/created serially.
86+
embeddings.DependsOn.Add(chat);
87+
cloudMachine.AddResource(embeddings);
88+
}
5989
}
6090
}
6191

@@ -72,6 +102,17 @@ public static ChatClient GetOpenAIChatClient(this ClientWorkspace workspace)
72102
return chatClient;
73103
}
74104

105+
public static EmbeddingClient GetOpenAIEmbeddingsClient(this ClientWorkspace workspace)
106+
{
107+
EmbeddingClient embeddingsClient = workspace.Subclients.Get(() =>
108+
{
109+
AzureOpenAIClient aoiaClient = workspace.Subclients.Get(() => CreateAzureOpenAIClient(workspace));
110+
return workspace.CreateEmbeddingsClient(aoiaClient);
111+
});
112+
113+
return embeddingsClient;
114+
}
115+
75116
private static AzureOpenAIClient CreateAzureOpenAIClient(this ClientWorkspace workspace)
76117
{
77118
ClientConnectionOptions connection = workspace.GetConnectionOptions(typeof(AzureOpenAIClient));
@@ -81,7 +122,7 @@ private static AzureOpenAIClient CreateAzureOpenAIClient(this ClientWorkspace wo
81122
}
82123
else
83124
{
84-
return new(connection.Endpoint, new ApiKeyCredential(connection.ApiKeyCredential!));
125+
return new(connection.Endpoint, new ApiKeyCredential(connection.ApiKeyCredential!));
85126
}
86127
}
87128

@@ -91,4 +132,11 @@ private static ChatClient CreateChatClient(this ClientWorkspace workspace, Azure
91132
ChatClient chat = client.GetChatClient(connection.Id);
92133
return chat;
93134
}
135+
136+
private static EmbeddingClient CreateEmbeddingsClient(this ClientWorkspace workspace, AzureOpenAIClient client)
137+
{
138+
ClientConnectionOptions connection = workspace.GetConnectionOptions(typeof(EmbeddingClient));
139+
EmbeddingClient embeddings = client.GetEmbeddingClient(connection.Id);
140+
return embeddings;
141+
}
94142
}

sdk/provisioning/Azure.Provisioning.CloudMachine/src/OFX/CloudMachineWorkspace.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ public override ClientConnectionOptions GetConnectionOptions(Type clientType, st
6464
return new ClientConnectionOptions(new($"https://{this.Id}.openai.azure.com"), Credential);
6565
case "OpenAI.Chat.ChatClient":
6666
return new ClientConnectionOptions(Id);
67+
case "OpenAI.Embeddings.EmbeddingClient":
68+
return new ClientConnectionOptions($"{Id}-embedding");
6769
default:
6870
throw new Exception($"unknown client {clientId}");
6971
}

sdk/provisioning/Azure.Provisioning.CloudMachine/src/OFX/StorageServices.cs

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
using Azure.Core;
1010
using Azure.Storage.Blobs;
1111
using Azure.Storage.Blobs.Models;
12+
using Azure.Storage.Blobs.Specialized;
1213

1314
namespace Azure.CloudMachine;
1415

@@ -42,28 +43,63 @@ private BlobContainerClient GetContainer(string containerName)
4243
return container;
4344
}
4445

45-
public string UploadJson(object json, string? name = default)
46+
public string UploadJson(object json, string? name = default, bool overwrite = false)
4647
{
4748
BlobContainerClient container = GetDefaultContainer();
4849

4950
if (name == default)
5051
name = $"b{Guid.NewGuid()}";
5152

52-
container.UploadBlob(name, BinaryData.FromObjectAsJson(json));
53+
var client = container.GetBlockBlobClient(name);
54+
var options = new BlobUploadOptions
55+
{
56+
Conditions = overwrite ? null : new BlobRequestConditions { IfNoneMatch = new ETag("*") },
57+
HttpHeaders = new BlobHttpHeaders { ContentType = ContentType.ApplicationJson.ToString() }
58+
};
59+
60+
client.Upload(BinaryData.FromObjectAsJson(json).ToStream(), options);
61+
return name;
62+
}
63+
64+
public string UploadStream(Stream fileStream, string? name = default, bool overwrite = false)
65+
{
66+
BlobContainerClient container = GetDefaultContainer();
67+
68+
if (name == default)
69+
name = $"b{Guid.NewGuid()}";
5370

71+
var client = container.GetBlockBlobClient(name);
72+
var options = new BlobUploadOptions
73+
{
74+
Conditions = overwrite ? null : new BlobRequestConditions { IfNoneMatch = new ETag("*") },
75+
HttpHeaders = new BlobHttpHeaders { ContentType = ContentType.ApplicationOctetStream.ToString() }
76+
};
77+
78+
client.Upload(fileStream, options);
5479
return name;
5580
}
56-
public string UploadBytes(BinaryData bytes, string? name = default)
81+
82+
public string UploadBinaryData(BinaryData data, string? name = default, bool overwrite = false)
5783
{
5884
BlobContainerClient container = GetDefaultContainer();
59-
if (name == default) name = $"b{Guid.NewGuid()}";
60-
container.UploadBlob(name, bytes);
85+
if (name == default)
86+
name = $"b{Guid.NewGuid()}";
87+
88+
var client = container.GetBlockBlobClient(name);
89+
var options = new BlobUploadOptions
90+
{
91+
Conditions = overwrite ? null : new BlobRequestConditions { IfNoneMatch = new ETag("*") },
92+
HttpHeaders = new BlobHttpHeaders { ContentType = ContentType.ApplicationOctetStream.ToString() }
93+
};
94+
95+
client.Upload(data.ToStream(), options);
6196
return name;
6297
}
63-
public string UploadBytes(byte[] bytes, string? name = default)
64-
=> UploadBytes(BinaryData.FromBytes(bytes), name);
65-
public string UploadBytes(ReadOnlyMemory<byte> bytes, string? name = default)
66-
=> UploadBytes(BinaryData.FromBytes(bytes), name);
98+
99+
public string UploadBytes(byte[] bytes, string? name = default, bool overwrite = false)
100+
=> UploadBinaryData(BinaryData.FromBytes(bytes), name, overwrite);
101+
public string UploadBytes(ReadOnlyMemory<byte> bytes, string? name = default, bool overwrite = false)
102+
=> UploadBinaryData(BinaryData.FromBytes(bytes), name, overwrite);
67103

68104
public BinaryData DownloadBlob(string path)
69105
{

sdk/provisioning/Azure.Provisioning.CloudMachine/tests/CloudMachineTests.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@ public void Provisioning(string[] args)
2424
if (CloudMachineInfrastructure.Configure(args, (cm) =>
2525
{
2626
cm.AddFeature(new KeyVaultFeature());
27-
cm.AddFeature(new OpenAIFeature("gpt-35-turbo", "0125"));
27+
cm.AddFeature(new OpenAIFeature(new AiModel("gpt-35-turbo", "0125"), new AiModel("text-embedding-ada-002", "2")));
2828
}))
2929
return;
3030

3131
CloudMachineWorkspace cm = new();
3232
Console.WriteLine(cm.Id);
33+
var embeddings = cm.GetOpenAIEmbeddingsClient();
3334
}
3435

3536
[Ignore("no recordings yet")]
@@ -71,7 +72,7 @@ public void OpenAI(string[] args)
7172
{
7273
if (CloudMachineInfrastructure.Configure(args, (cm) =>
7374
{
74-
cm.AddFeature(new OpenAIFeature("gpt-35-turbo", "0125"));
75+
cm.AddFeature(new OpenAIFeature(new AiModel("gpt-35-turbo", "0125")));
7576
}))
7677
return;
7778

@@ -137,7 +138,7 @@ public void Demo(string[] args)
137138
CloudMachineClient cm = new();
138139

139140
// setup
140-
cm.Messaging.WhenMessageReceived((string message) => cm.Storage.UploadBytes(BinaryData.FromString(message)));
141+
cm.Messaging.WhenMessageReceived((string message) => cm.Storage.UploadBinaryData(BinaryData.FromString(message)));
141142
cm.Storage.WhenBlobUploaded((StorageFile file) =>
142143
{
143144
var content = file.Download();

0 commit comments

Comments
 (0)