Skip to content

Commit 0083cf6

Browse files
committed
feat: add baichuan models
1 parent 41610b4 commit 0083cf6

15 files changed

+314
-99
lines changed
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
namespace Cnblogs.DashScope.Sdk.BaiChuan;

/// <summary>
/// Models in the BaiChuan2 family; these support both prompt and message input formats.
/// </summary>
public enum BaiChuan2Llm
{
    /// <summary>
    /// The baichuan2-7b-chat-v1 model.
    /// </summary>
    BaiChuan2_7BChatV1 = 1,

    /// <summary>
    /// The baichuan2-13b-chat-v1 model.
    /// </summary>
    BaiChuan2_13BChatV1 = 2
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
namespace Cnblogs.DashScope.Sdk.BaiChuan;

/// <summary>
/// Supported baichuan model: https://help.aliyun.com/zh/dashscope/developer-reference/api-details-2
/// </summary>
public enum BaiChuanLlm
{
    /// <summary>
    /// The baichuan-7b-v1 model.
    /// </summary>
    BaiChuan7B = 1
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
using Cnblogs.DashScope.Sdk.Internals;

namespace Cnblogs.DashScope.Sdk.BaiChuan;

/// <summary>
/// Maps BaiChuan model enum members to the model-name strings the DashScope API expects.
/// </summary>
internal static class BaiChuanLlmName
{
    /// <summary>
    /// Gets the DashScope model name for a <see cref="BaiChuanLlm"/> member.
    /// </summary>
    /// <param name="llm">The model enum value.</param>
    /// <returns>The corresponding model name string.</returns>
    public static string GetModelName(this BaiChuanLlm llm)
        => llm switch
        {
            BaiChuanLlm.BaiChuan7B => "baichuan-7b-v1",
            // Unknown enum values throw; callers with newer models should use the string overloads.
            _ => ThrowHelper.UnknownModelName(nameof(llm), llm)
        };

    /// <summary>
    /// Gets the DashScope model name for a <see cref="BaiChuan2Llm"/> member.
    /// </summary>
    /// <param name="llm">The model enum value.</param>
    /// <returns>The corresponding model name string.</returns>
    public static string GetModelName(this BaiChuan2Llm llm)
        => llm switch
        {
            BaiChuan2Llm.BaiChuan2_7BChatV1 => "baichuan2-7b-chat-v1",
            BaiChuan2Llm.BaiChuan2_13BChatV1 => "baichuan2-13b-chat-v1",
            // Unknown enum values throw; callers with newer models should use the string overloads.
            _ => ThrowHelper.UnknownModelName(nameof(llm), llm)
        };
}
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
namespace Cnblogs.DashScope.Sdk.BaiChuan;

/// <summary>
/// BaiChuan LLM generation apis, doc: https://help.aliyun.com/zh/dashscope/developer-reference/api-details-2
/// </summary>
public static class BaiChuanTextGenerationApi
{
    /// <summary>
    /// Get text completion from baichuan model.
    /// </summary>
    /// <param name="client">The <see cref="IDashScopeClient"/>.</param>
    /// <param name="llm">The llm to use.</param>
    /// <param name="prompt">The prompt to generate completion from.</param>
    /// <returns>The text generation response.</returns>
    public static Task<ModelResponse<TextGenerationOutput, TextGenerationTokenUsage>> GetBaiChuanTextCompletionAsync(
        this IDashScopeClient client,
        BaiChuanLlm llm,
        string prompt)
        // Delegate to the string overload after resolving the enum to its model name.
        => client.GetBaiChuanTextCompletionAsync(llm.GetModelName(), prompt);

    /// <summary>
    /// Get text completion from baichuan model.
    /// </summary>
    /// <param name="client">The <see cref="IDashScopeClient"/>.</param>
    /// <param name="llm">The llm to use.</param>
    /// <param name="prompt">The prompt to generate completion from.</param>
    /// <returns>The text generation response.</returns>
    public static Task<ModelResponse<TextGenerationOutput, TextGenerationTokenUsage>> GetBaiChuanTextCompletionAsync(
        this IDashScopeClient client,
        string llm,
        string prompt)
    {
        var request = new ModelRequest<TextGenerationInput, TextGenerationParameters>
        {
            Model = llm,
            Input = new TextGenerationInput { Prompt = prompt },
            // Prompt-style calls send no parameters.
            Parameters = null
        };
        return client.GetTextCompletionAsync(request);
    }

    /// <summary>
    /// Get text completion from baichuan model.
    /// </summary>
    /// <param name="client">The <see cref="IDashScopeClient"/>.</param>
    /// <param name="llm">The llm to use.</param>
    /// <param name="messages">The context messages.</param>
    /// <param name="resultFormat">Can be 'text' or 'message', defaults to 'text'. Call <see cref="ResultFormats"/> to get available options.</param>
    /// <returns>The text generation response.</returns>
    public static Task<ModelResponse<TextGenerationOutput, TextGenerationTokenUsage>> GetBaiChuanTextCompletionAsync(
        this IDashScopeClient client,
        BaiChuan2Llm llm,
        IEnumerable<ChatMessage> messages,
        string? resultFormat = null)
        // Delegate to the string overload after resolving the enum to its model name.
        => client.GetBaiChuanTextCompletionAsync(llm.GetModelName(), messages, resultFormat);

    /// <summary>
    /// Get text completion from baichuan model.
    /// </summary>
    /// <param name="client">The <see cref="IDashScopeClient"/>.</param>
    /// <param name="llm">The model name.</param>
    /// <param name="messages">The context messages.</param>
    /// <param name="resultFormat">Can be 'text' or 'message', defaults to 'text'. Call <see cref="ResultFormats"/> to get available options.</param>
    /// <returns>The text generation response.</returns>
    public static Task<ModelResponse<TextGenerationOutput, TextGenerationTokenUsage>> GetBaiChuanTextCompletionAsync(
        this IDashScopeClient client,
        string llm,
        IEnumerable<ChatMessage> messages,
        string? resultFormat = null)
    {
        // Only attach parameters when the caller asked for a specific result format.
        var parameters = string.IsNullOrEmpty(resultFormat)
            ? null
            : new TextGenerationParameters { ResultFormat = resultFormat };
        var request = new ModelRequest<TextGenerationInput, TextGenerationParameters>
        {
            Model = llm,
            Input = new TextGenerationInput { Messages = messages },
            Parameters = parameters
        };
        return client.GetTextCompletionAsync(request);
    }
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using System.Diagnostics.CodeAnalysis;

namespace Cnblogs.DashScope.Sdk.Internals;

/// <summary>
/// Centralized throw helpers shared by the model-name mapping classes.
/// </summary>
internal static class ThrowHelper
{
    /// <summary>
    /// Throws an <see cref="ArgumentOutOfRangeException"/> for an unrecognized model enum value.
    /// Declared with a <see cref="string"/> return type so it can be used directly
    /// as a switch-expression arm; it never actually returns.
    /// </summary>
    /// <param name="argumentName">The name of the offending argument.</param>
    /// <param name="value">The unrecognized enum value.</param>
    /// <returns>Never returns; always throws.</returns>
    [DoesNotReturn]
    public static string UnknownModelName(string argumentName, object value)
        => throw new ArgumentOutOfRangeException(
            argumentName,
            value,
            "Unknown model type, please use the overload that accepts a string ‘model’ parameter.");
}
Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
namespace Cnblogs.DashScope.Sdk.Llama2;
1+
using Cnblogs.DashScope.Sdk.Internals;
2+
3+
namespace Cnblogs.DashScope.Sdk.Llama2;
24

35
internal static class Llama2ModelNames
46
{
@@ -8,10 +10,7 @@ public static string GetModelName(this Llama2Model model)
810
{
911
Llama2Model.Chat7Bv2 => "llama2-7b-chat-v2",
1012
Llama2Model.Chat13Bv2 => "llama2-13b-chat-v2",
11-
_ => throw new ArgumentOutOfRangeException(
12-
nameof(model),
13-
model,
14-
"Unknown model type, please use the overload that accepts a string ‘model’ parameter.")
13+
_ => ThrowHelper.UnknownModelName(nameof(model), model)
1514
};
1615
}
1716
}

src/Cnblogs.DashScope.Sdk/QWen/QWenLlmNames.cs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
namespace Cnblogs.DashScope.Sdk.QWen;
1+
using Cnblogs.DashScope.Sdk.Internals;
2+
3+
namespace Cnblogs.DashScope.Sdk.QWen;
24

35
internal static class QWenLlmNames
46
{
@@ -19,10 +21,7 @@ public static string GetModelName(this QWenLlm llm)
1921
QWenLlm.QWen7BChat => "qwen-7b-chat",
2022
QWenLlm.QWen1_8BLongContextChat => "qwen-1.8b-longcontext-chat",
2123
QWenLlm.QWen1_8Chat => "qwen-1.8b-chat",
22-
_ => throw new ArgumentOutOfRangeException(
23-
nameof(llm),
24-
llm,
25-
"Unknown model type, please use the overload that accepts a string ‘model’ parameter.")
24+
_ => ThrowHelper.UnknownModelName(nameof(llm), llm)
2625
};
2726
}
2827
}

src/Cnblogs.DashScope.Sdk/QWenMultimodal/QWenMultimodalModelNames.cs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
namespace Cnblogs.DashScope.Sdk.QWenMultimodal;
1+
using Cnblogs.DashScope.Sdk.Internals;
2+
3+
namespace Cnblogs.DashScope.Sdk.QWenMultimodal;
24

35
internal static class QWenMultimodalModelNames
46
{
@@ -12,10 +14,7 @@ public static string GetModelName(this QWenMultimodalModel multimodalModel)
1214
QWenMultimodalModel.QWenVlV1 => "qwen-vl-v1",
1315
QWenMultimodalModel.QWenVlChatV1 => "qwen-vl-chat-v1",
1416
QWenMultimodalModel.QWenAudioChat => "qwen-audio-chat",
15-
_ => throw new ArgumentOutOfRangeException(
16-
nameof(multimodalModel),
17-
multimodalModel,
18-
"Unknown model type, please use the overload that accepts a string ‘model’ parameter.")
17+
_ => ThrowHelper.UnknownModelName(nameof(multimodalModel), multimodalModel)
1918
};
2019
}
2120
}
Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
namespace Cnblogs.DashScope.Sdk.TextEmbedding;
1+
using Cnblogs.DashScope.Sdk.Internals;
2+
3+
namespace Cnblogs.DashScope.Sdk.TextEmbedding;
24

35
internal static class TextEmbeddingModelNames
46
{
@@ -8,10 +10,7 @@ public static string GetModelName(this TextEmbeddingModel model)
810
{
911
TextEmbeddingModel.TextEmbeddingV1 => "text-embedding-v1",
1012
TextEmbeddingModel.TextEmbeddingV2 => "text-embedding-v2",
11-
_ => throw new ArgumentOutOfRangeException(
12-
nameof(model),
13-
model,
14-
"Unknown model type, please use the overload with ‘model’ as a string type.")
13+
_ => ThrowHelper.UnknownModelName(nameof(model), model)
1514
};
1615
}
1716
}
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
using Cnblogs.DashScope.Sdk.BaiChuan;
using Cnblogs.DashScope.Sdk.UnitTests.Utils;
using NSubstitute;

namespace Cnblogs.DashScope.Sdk.UnitTests;

/// <summary>
/// Tests for the BaiChuan text generation extension methods, covering both the
/// enum-based and custom-string model overloads.
/// </summary>
public class BaiChuanTextGenerationTests
{
    [Fact]
    public async Task BaiChuanTextGeneration_UseEnum_SuccessAsync()
    {
        // Arrange
        var client = Substitute.For<IDashScopeClient>();

        // Act
        _ = await client.GetBaiChuanTextCompletionAsync(BaiChuanLlm.BaiChuan7B, Cases.Prompt);

        // Assert: the enum must be mapped to its model-name string and no parameters sent.
        _ = await client.Received().GetTextCompletionAsync(
            Arg.Is<ModelRequest<TextGenerationInput, TextGenerationParameters>>(
                s => s.Model == "baichuan-7b-v1" && s.Input.Prompt == Cases.Prompt && s.Parameters == null));
    }

    [Fact]
    public async Task BaiChuanTextGeneration_CustomModel_SuccessAsync()
    {
        // Arrange
        var client = Substitute.For<IDashScopeClient>();

        // Act: fixed — was a duplicate of the enum test and never exercised the
        // string-model overload; use the custom model name instead.
        _ = await client.GetBaiChuanTextCompletionAsync(Cases.CustomModelName, Cases.Prompt);

        // Assert: the custom model name must be passed through unchanged.
        _ = await client.Received().GetTextCompletionAsync(
            Arg.Is<ModelRequest<TextGenerationInput, TextGenerationParameters>>(
                s => s.Model == Cases.CustomModelName && s.Input.Prompt == Cases.Prompt && s.Parameters == null));
    }

    [Fact]
    public async Task BaiChuan2TextGeneration_UseEnum_SuccessAsync()
    {
        // Arrange
        var client = Substitute.For<IDashScopeClient>();

        // Act
        _ = await client.GetBaiChuanTextCompletionAsync(
            BaiChuan2Llm.BaiChuan2_13BChatV1,
            Cases.TextMessages,
            ResultFormats.Message);

        // Assert: enum mapped to model name, messages and result format forwarded.
        _ = await client.Received().GetTextCompletionAsync(
            Arg.Is<ModelRequest<TextGenerationInput, TextGenerationParameters>>(
                s => s.Model == "baichuan2-13b-chat-v1"
                     && s.Input.Messages == Cases.TextMessages
                     && s.Parameters != null
                     && s.Parameters.ResultFormat == ResultFormats.Message));
    }

    [Fact]
    public async Task BaiChuan2TextGeneration_CustomModel_SuccessAsync()
    {
        // Arrange
        var client = Substitute.For<IDashScopeClient>();

        // Act
        _ = await client.GetBaiChuanTextCompletionAsync(
            Cases.CustomModelName,
            Cases.TextMessages,
            ResultFormats.Message);

        // Assert: custom model name passed through, messages and result format forwarded.
        _ = await client.Received().GetTextCompletionAsync(
            Arg.Is<ModelRequest<TextGenerationInput, TextGenerationParameters>>(
                s => s.Model == Cases.CustomModelName
                     && s.Input.Messages == Cases.TextMessages
                     && s.Parameters != null
                     && s.Parameters.ResultFormat == ResultFormats.Message));
    }
}

0 commit comments

Comments
 (0)