Skip to content

Commit a1bf53c

Browse files
committed
Add DelegatingAIFunction (#6565)
* Add DelegatingAIFunction To simplify scenarios where someone wants to augment an existing AIFunction's behavior, tweak what one of its properties returns, etc. * Address PR feedback
1 parent f78d287 commit a1bf53c

File tree

3 files changed

+205
-0
lines changed

3 files changed

+205
-0
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
5+
using System.Collections.Generic;
6+
using System.Reflection;
7+
using System.Text.Json;
8+
using System.Threading;
9+
using System.Threading.Tasks;
10+
using Microsoft.Shared.Diagnostics;
11+
12+
#pragma warning disable SA1202 // Elements should be ordered by access
13+
14+
namespace Microsoft.Extensions.AI;
15+
16+
/// <summary>
17+
/// Provides an optional base class for an <see cref="AIFunction"/> that passes through calls to another instance.
18+
/// </summary>
19+
public class DelegatingAIFunction : AIFunction
20+
{
21+
/// <summary>
22+
/// Initializes a new instance of the <see cref="DelegatingAIFunction"/> class as a wrapper around <paramref name="innerFunction"/>.
23+
/// </summary>
24+
/// <param name="innerFunction">The inner AI function to which all calls are delegated by default.</param>
25+
/// <exception cref="ArgumentNullException"><paramref name="innerFunction"/> is <see langword="null"/>.</exception>
26+
protected DelegatingAIFunction(AIFunction innerFunction)
27+
{
28+
InnerFunction = Throw.IfNull(innerFunction);
29+
}
30+
31+
/// <summary>Gets the inner <see cref="AIFunction" />.</summary>
32+
protected AIFunction InnerFunction { get; }
33+
34+
/// <inheritdoc />
35+
public override string Name => InnerFunction.Name;
36+
37+
/// <inheritdoc />
38+
public override string Description => InnerFunction.Description;
39+
40+
/// <inheritdoc />
41+
public override JsonElement JsonSchema => InnerFunction.JsonSchema;
42+
43+
/// <inheritdoc />
44+
public override JsonElement? ReturnJsonSchema => InnerFunction.ReturnJsonSchema;
45+
46+
/// <inheritdoc />
47+
public override JsonSerializerOptions JsonSerializerOptions => InnerFunction.JsonSerializerOptions;
48+
49+
/// <inheritdoc />
50+
public override MethodInfo? UnderlyingMethod => InnerFunction.UnderlyingMethod;
51+
52+
/// <inheritdoc />
53+
public override IReadOnlyDictionary<string, object?> AdditionalProperties => InnerFunction.AdditionalProperties;
54+
55+
/// <inheritdoc />
56+
public override string ToString() => InnerFunction.ToString();
57+
58+
/// <inheritdoc />
59+
protected override ValueTask<object?> InvokeCoreAsync(AIFunctionArguments arguments, CancellationToken cancellationToken) =>
60+
InnerFunction.InvokeAsync(arguments, cancellationToken);
61+
}

src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.json

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1355,6 +1355,58 @@
13551355
}
13561356
]
13571357
},
1358+
{
1359+
"Type": "class Microsoft.Extensions.AI.DelegatingAIFunction : Microsoft.Extensions.AI.AIFunction",
1360+
"Stage": "Stable",
1361+
"Methods": [
1362+
{
1363+
"Member": "Microsoft.Extensions.AI.DelegatingAIFunction.DelegatingAIFunction(Microsoft.Extensions.AI.AIFunction innerFunction);",
1364+
"Stage": "Stable"
1365+
},
1366+
{
1367+
"Member": "override System.Threading.Tasks.ValueTask<object?> Microsoft.Extensions.AI.DelegatingAIFunction.InvokeCoreAsync(Microsoft.Extensions.AI.AIFunctionArguments arguments, System.Threading.CancellationToken cancellationToken);",
1368+
"Stage": "Stable"
1369+
},
1370+
{
1371+
"Member": "override string Microsoft.Extensions.AI.DelegatingAIFunction.ToString();",
1372+
"Stage": "Experimental"
1373+
}
1374+
],
1375+
"Properties": [
1376+
{
1377+
"Member": "Microsoft.Extensions.AI.AIFunction Microsoft.Extensions.AI.DelegatingAIFunction.InnerFunction { get; }",
1378+
"Stage": "Stable"
1379+
},
1380+
{
1381+
"Member": "override System.Collections.Generic.IReadOnlyDictionary<string, object?> Microsoft.Extensions.AI.DelegatingAIFunction.AdditionalProperties { get; }",
1382+
"Stage": "Stable"
1383+
},
1384+
{
1385+
"Member": "override string Microsoft.Extensions.AI.DelegatingAIFunction.Description { get; }",
1386+
"Stage": "Stable"
1387+
},
1388+
{
1389+
"Member": "override System.Text.Json.JsonElement Microsoft.Extensions.AI.DelegatingAIFunction.JsonSchema { get; }",
1390+
"Stage": "Stable"
1391+
},
1392+
{
1393+
"Member": "override System.Text.Json.JsonSerializerOptions Microsoft.Extensions.AI.DelegatingAIFunction.JsonSerializerOptions { get; }",
1394+
"Stage": "Stable"
1395+
},
1396+
{
1397+
"Member": "override string Microsoft.Extensions.AI.DelegatingAIFunction.Name { get; }",
1398+
"Stage": "Stable"
1399+
},
1400+
{
1401+
"Member": "override System.Text.Json.JsonElement? Microsoft.Extensions.AI.DelegatingAIFunction.ReturnJsonSchema { get; }",
1402+
"Stage": "Stable"
1403+
},
1404+
{
1405+
"Member": "override System.Reflection.MethodInfo? Microsoft.Extensions.AI.DelegatingAIFunction.UnderlyingMethod { get; }",
1406+
"Stage": "Stable"
1407+
}
1408+
]
1409+
},
13581410
{
13591411
"Type": "class Microsoft.Extensions.AI.DelegatingChatClient : Microsoft.Extensions.AI.IChatClient, System.IDisposable",
13601412
"Stage": "Stable",
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
5+
using System.Reflection;
6+
using System.Threading;
7+
using System.Threading.Tasks;
8+
using Xunit;
9+
10+
namespace Microsoft.Extensions.AI;
11+
12+
public class DelegatingAIFunctionTests
13+
{
14+
[Fact]
15+
public void Constructor_NullInnerFunction_ThrowsArgumentNullException()
16+
{
17+
Assert.Throws<ArgumentNullException>("innerFunction", () => new DerivedFunction(null!));
18+
}
19+
20+
[Fact]
21+
public void DefaultOverrides_DelegateToInnerFunction()
22+
{
23+
AIFunction expected = AIFunctionFactory.Create(() => 42);
24+
DerivedFunction actual = new(expected);
25+
26+
Assert.Same(expected, actual.InnerFunction);
27+
Assert.Equal(expected.Name, actual.Name);
28+
Assert.Equal(expected.Description, actual.Description);
29+
Assert.Equal(expected.JsonSchema, actual.JsonSchema);
30+
Assert.Equal(expected.ReturnJsonSchema, actual.ReturnJsonSchema);
31+
Assert.Same(expected.JsonSerializerOptions, actual.JsonSerializerOptions);
32+
Assert.Same(expected.UnderlyingMethod, actual.UnderlyingMethod);
33+
Assert.Same(expected.AdditionalProperties, actual.AdditionalProperties);
34+
Assert.Equal(expected.ToString(), actual.ToString());
35+
}
36+
37+
private sealed class DerivedFunction(AIFunction innerFunction) : DelegatingAIFunction(innerFunction)
38+
{
39+
public new AIFunction InnerFunction => base.InnerFunction;
40+
}
41+
42+
[Fact]
43+
public void Virtuals_AllOverridden()
44+
{
45+
Assert.All(typeof(DelegatingAIFunction).GetMembers(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance), m =>
46+
{
47+
switch (m)
48+
{
49+
case MethodInfo methodInfo when methodInfo.IsVirtual && methodInfo.Name is not ("Finalize" or "Equals" or "GetHashCode"):
50+
Assert.True(methodInfo.DeclaringType == typeof(DelegatingAIFunction), $"{methodInfo.Name} not overridden");
51+
break;
52+
53+
case PropertyInfo propertyInfo when propertyInfo.GetMethod?.IsVirtual is true:
54+
Assert.True(propertyInfo.DeclaringType == typeof(DelegatingAIFunction), $"{propertyInfo.Name} not overridden");
55+
break;
56+
}
57+
});
58+
}
59+
60+
[Fact]
61+
public async Task OverriddenInvocation_SuccessfullyInvoked()
62+
{
63+
bool innerInvoked = false;
64+
AIFunction inner = AIFunctionFactory.Create(int () =>
65+
{
66+
innerInvoked = true;
67+
throw new Exception("uh oh");
68+
}, "TestFunction", "A test function for DelegatingAIFunction");
69+
70+
AIFunction actual = new OverridesInvocation(inner, (args, ct) => new ValueTask<object?>(84));
71+
72+
Assert.Equal(inner.Name, actual.Name);
73+
Assert.Equal(inner.Description, actual.Description);
74+
Assert.Equal(inner.JsonSchema, actual.JsonSchema);
75+
Assert.Equal(inner.ReturnJsonSchema, actual.ReturnJsonSchema);
76+
Assert.Same(inner.JsonSerializerOptions, actual.JsonSerializerOptions);
77+
Assert.Same(inner.UnderlyingMethod, actual.UnderlyingMethod);
78+
Assert.Same(inner.AdditionalProperties, actual.AdditionalProperties);
79+
Assert.Equal(inner.ToString(), actual.ToString());
80+
81+
object? result = await actual.InvokeAsync(new(), CancellationToken.None);
82+
Assert.Contains("84", result?.ToString());
83+
84+
Assert.False(innerInvoked);
85+
}
86+
87+
private sealed class OverridesInvocation(AIFunction innerFunction, Func<AIFunctionArguments, CancellationToken, ValueTask<object?>> invokeAsync) : DelegatingAIFunction(innerFunction)
88+
{
89+
protected override ValueTask<object?> InvokeCoreAsync(AIFunctionArguments arguments, CancellationToken cancellationToken) =>
90+
invokeAsync(arguments, cancellationToken);
91+
}
92+
}

0 commit comments

Comments
 (0)