Skip to content

Commit d88ed81

Browse files
committed
Support for declarative build function callbacks in the Spring container
-- add FunctionCallbackMethodProcessor to support parsing. -- add the FunctionCalling annotation to declare -- add FunctionCallbackMethodProcessorIT as test Dependent on spring-projects#1099
1 parent e7e2e92 commit d88ed81

File tree

4 files changed

+271
-0
lines changed

4 files changed

+271
-0
lines changed
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
/*
2+
* Copyright 2024 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.model.function;
17+
18+
import org.apache.commons.logging.Log;
19+
import org.apache.commons.logging.LogFactory;
20+
import org.springframework.aop.framework.autoproxy.AutoProxyUtils;
21+
import org.springframework.aop.scope.ScopedObject;
22+
import org.springframework.aop.scope.ScopedProxyUtils;
23+
import org.springframework.beans.BeansException;
24+
import org.springframework.beans.factory.BeanInitializationException;
25+
import org.springframework.beans.factory.SmartInitializingSingleton;
26+
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
27+
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
28+
import org.springframework.core.MethodIntrospector;
29+
import org.springframework.core.annotation.AnnotatedElementUtils;
30+
import org.springframework.core.annotation.AnnotationUtils;
31+
import org.springframework.lang.NonNull;
32+
import org.springframework.lang.Nullable;
33+
import org.springframework.stereotype.Component;
34+
import org.springframework.util.Assert;
35+
import org.springframework.util.ClassUtils;
36+
import org.springframework.util.CollectionUtils;
37+
import org.springframework.util.ReflectionUtils;
38+
39+
import java.lang.reflect.Method;
40+
import java.lang.reflect.Modifier;
41+
import java.util.Collections;
42+
import java.util.Map;
43+
import java.util.Set;
44+
import java.util.concurrent.ConcurrentHashMap;
45+
46+
/**
47+
* {@link BeanFactoryPostProcessor} that processes {@link FunctionCalling} annotations.
48+
* <p>
49+
* <p>Any such annotated method is registered as a {@link MethodFunctionCallback} bean in the
50+
* application context.
51+
* <p>
52+
* <p>Processing of {@code @FunctionCalling} annotations can be customized through the
53+
* {@link FunctionCalling} annotation.
54+
* <p>
55+
*
56+
* @see FunctionCalling
57+
* @see MethodFunctionCallback
58+
* @author kamosama
59+
*/
60+
public class FunctionCallbackMethodProcessor
61+
implements SmartInitializingSingleton, BeanFactoryPostProcessor {
62+
63+
protected final Log logger = LogFactory.getLog(getClass());
64+
65+
private final Set<Class<?>> nonAnnotatedClasses = Collections.newSetFromMap(new ConcurrentHashMap<>(64));
66+
67+
@Nullable
68+
private ConfigurableListableBeanFactory beanFactory;
69+
70+
71+
@Override
72+
public void afterSingletonsInstantiated() {
73+
Assert.state(this.beanFactory != null, "No ConfigurableListableBeanFactory set");
74+
75+
String[] beanNames = beanFactory.getBeanNamesForType(Object.class);
76+
77+
for (String beanName : beanNames) {
78+
if (ScopedProxyUtils.isScopedTarget(beanName)) {
79+
continue;
80+
}
81+
Class<?> type = null;
82+
try {
83+
type = AutoProxyUtils.determineTargetClass(beanFactory, beanName);
84+
} catch (Throwable ex) {
85+
// An unresolvable bean type, probably from a lazy bean - let's ignore it.
86+
if (logger.isDebugEnabled()) {
87+
logger.debug("Could not resolve target class for bean with name '" + beanName + "'", ex);
88+
}
89+
}
90+
if (type == null) {
91+
continue;
92+
}
93+
if (ScopedObject.class.isAssignableFrom(type)) {
94+
try {
95+
Class<?> targetClass = AutoProxyUtils.determineTargetClass(
96+
beanFactory, ScopedProxyUtils.getTargetBeanName(beanName));
97+
if (targetClass != null) {
98+
type = targetClass;
99+
}
100+
} catch (Throwable ex) {
101+
// An invalid scoped proxy arrangement - let's ignore it.
102+
if (logger.isDebugEnabled()) {
103+
logger.debug("Could not resolve target bean for scoped proxy '" + beanName + "'", ex);
104+
}
105+
}
106+
}
107+
try {
108+
processBean(beanName, type);
109+
} catch (Throwable ex) {
110+
throw new BeanInitializationException("Failed to process @FunctionCalling " +
111+
"annotation on bean with name '" + beanName + "'", ex);
112+
}
113+
}
114+
}
115+
116+
private void processBean(final String beanName, final Class<?> targetType) {
117+
Assert.state(this.beanFactory != null, "No ConfigurableListableBeanFactory set");
118+
119+
if (!this.nonAnnotatedClasses.contains(targetType)
120+
&& AnnotationUtils.isCandidateClass(targetType, FunctionCalling.class)
121+
&& !isSpringContainerClass(targetType)
122+
) {
123+
124+
Map<Method, FunctionCalling> annotatedMethods = null;
125+
try {
126+
annotatedMethods = MethodIntrospector.selectMethods(targetType,
127+
(MethodIntrospector.MetadataLookup<FunctionCalling>) method ->
128+
AnnotatedElementUtils.findMergedAnnotation(method, FunctionCalling.class));
129+
} catch (Throwable ex) {
130+
// An unresolvable type in a method signature, probably from a lazy bean - let's ignore it.
131+
if (logger.isDebugEnabled()) {
132+
logger.debug("Could not resolve methods for bean with name '" + beanName + "'", ex);
133+
}
134+
}
135+
136+
if (CollectionUtils.isEmpty(annotatedMethods)) {
137+
this.nonAnnotatedClasses.add(targetType);
138+
if (logger.isTraceEnabled()) {
139+
logger.trace("No @FunctionCalling annotations found on bean class: " + targetType.getName());
140+
}
141+
} else {
142+
// Non-empty set of methods
143+
annotatedMethods.forEach((method, annotation) -> {
144+
String name = annotation.name().isEmpty() ? method.getName() : annotation.name();
145+
ReflectionUtils.makeAccessible(method);
146+
var functionObject = Modifier.isStatic(method.getModifiers()) ? null : beanFactory.getBean(beanName);
147+
MethodFunctionCallback callback = MethodFunctionCallback.builder()
148+
.withFunctionObject(functionObject)
149+
.withMethod(method)
150+
.withDescription(annotation.description())
151+
.build();
152+
beanFactory.registerSingleton(name, callback);
153+
});
154+
155+
if (logger.isDebugEnabled()) {
156+
logger.debug(annotatedMethods.size() + " @FunctionCalling methods processed on bean '" +
157+
beanName + "': " + annotatedMethods);
158+
}
159+
}
160+
}
161+
}
162+
163+
/**
164+
* Determine whether the given class is an {@code org.springframework}
165+
* bean class that is not annotated as a user or test {@link Component}...
166+
* which indicates that there is no {@link FunctionCalling} to be found there.
167+
*/
168+
private static boolean isSpringContainerClass(Class<?> clazz) {
169+
return (clazz.getName().startsWith("org.springframework.") &&
170+
!AnnotatedElementUtils.isAnnotated(ClassUtils.getUserClass(clazz), Component.class));
171+
}
172+
173+
@Override
174+
public void postProcessBeanFactory(@NonNull ConfigurableListableBeanFactory beanFactory) throws BeansException {
175+
this.beanFactory = beanFactory;
176+
}
177+
178+
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* Copyright 2024 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.model.function;
17+
18+
import java.lang.annotation.*;
19+
20+
/**
21+
* Annotation to indicate that a method is a AI function calling.
22+
*
23+
* @see FunctionCallbackMethodProcessor
24+
* @author kamosama
25+
*/
26+
@Target({ElementType.METHOD, ElementType.ANNOTATION_TYPE})
27+
@Retention(RetentionPolicy.RUNTIME)
28+
@Documented
29+
public @interface FunctionCalling {
30+
31+
String name() default "";
32+
33+
String description();
34+
35+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package org.springframework1.ai.model.function;
2+
3+
import org.springframework.ai.model.function.FunctionCalling;
4+
5+
import java.time.LocalDateTime;
6+
7+
public class FunctionCallConfig {
8+
@FunctionCalling(name = "dateTime", description = "get the current date and time")
9+
public String dateTime(String location) {
10+
return location + " dateTime:" + LocalDateTime.now();
11+
}
12+
13+
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Copyright 2024 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.model.function;
17+
18+
import org.junit.jupiter.api.Test;
19+
import org.slf4j.Logger;
20+
import org.slf4j.LoggerFactory;
21+
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
22+
23+
public class FunctionCallbackMethodProcessorIT {
24+
25+
private static final Logger logger = LoggerFactory.getLogger(FunctionCallbackMethodProcessorIT.class);
26+
27+
28+
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
29+
.withBean(FunctionCallbackMethodProcessor.class)
30+
// The "1" is added to the package path for compatibility with the !isSpringContainerClass method.
31+
.withBean(org.springframework1.ai.model.function.FunctionCallConfig.class);
32+
33+
@Test
34+
public void testFunctionCallbackMethodProcessor() {
35+
contextRunner.run(context -> {
36+
FunctionCallback functionCallback = context.getBean(FunctionCallback.class);
37+
logger.info("FunctionCallback: name:{}, description:{}",
38+
functionCallback.getName(), functionCallback.getDescription());
39+
String result = functionCallback.call("{\"location\":\"New York\"}");
40+
logger.info("Result: {}", result);
41+
assert result.contains("New York");
42+
});
43+
}
44+
45+
}

0 commit comments

Comments
 (0)