Skip to content

Commit 1dce478

Browse files
add python-langchain-tools generator to generate langchain tools from spec
1 parent 6f3daca commit 1dce478

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

68 files changed

+11390
-0
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
generatorName: python-langchain-tools
2+
outputDir: samples/client/petstore/python/langchain/tools
3+
inputSpec: modules/openapi-generator/src/test/resources/3_0/petstore.yaml
4+
templateDir: modules/openapi-generator/src/main/resources/python-langchain-tools
5+
additionalProperties:
6+
hideGenerationTimestamp: "true"
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
package org.openapitools.codegen.languages;
2+
3+
import java.io.File;
4+
import java.io.IOException;
5+
import java.nio.charset.StandardCharsets;
6+
import java.nio.file.Files;
7+
import java.nio.file.StandardOpenOption;
8+
9+
import org.openapitools.codegen.SupportingFile;
10+
import org.slf4j.Logger;
11+
import org.slf4j.LoggerFactory;
12+
13+
public class PythonLangchainToolsClientCodegen extends PythonClientCodegen {
14+
public static final String PROJECT_NAME = "projectName";
15+
16+
private final Logger LOGGER = LoggerFactory.getLogger(PythonLangchainToolsClientCodegen.class);
17+
18+
public PythonLangchainToolsClientCodegen() {
19+
super();
20+
apiTemplateFiles.put("api_tools.mustache", "_tools.py");
21+
supportingFiles.add(new SupportingFile("all_tools.mustache", "", "all_tools.py"));
22+
}
23+
24+
@Override
25+
public String getName() {
26+
return "python-langchain-tools";
27+
}
28+
29+
@Override
30+
public String getHelp() {
31+
return "Generates a Python client for LangChain agent tools, grouped by tags.";
32+
}
33+
34+
@Override
35+
public String apiFileFolder() {
36+
return outputFolder + File.separator + apiPackage().replace('.', File.separatorChar);
37+
}
38+
39+
@Override
40+
public String modelFileFolder() {
41+
return outputFolder + File.separator + modelPackage().replace('.', File.separatorChar);
42+
}
43+
44+
@Override
45+
public void processOpts() {
46+
super.processOpts();
47+
48+
// Now that packageName is processed and available, we can add the supporting file
49+
// for the tools package __init__.py in the correct directory.
50+
final String toolsPackagePath = packageName.replace(".", File.separator) + File.separator + "tools";
51+
supportingFiles.add(new SupportingFile("init.mustache", toolsPackagePath, "__init__.py"));
52+
}
53+
54+
@Override
55+
public boolean isEnablePostProcessFile() {
56+
return true;
57+
}
58+
59+
/**
60+
* This is a way to add a dependency without duplicating the parent template.
61+
*/
62+
@Override
63+
public void postProcessFile(File file, String fileType) {
64+
super.postProcessFile(file, fileType);
65+
if (file == null) {
66+
return;
67+
}
68+
final String filename = file.getName();
69+
if ("requirements.txt".equals(filename)) {
70+
try {
71+
String langchainDep = "\nlangchain >=0.3, <0.4\n";
72+
Files.write(file.toPath(), langchainDep.getBytes(StandardCharsets.UTF_8), StandardOpenOption.APPEND);
73+
} catch (IOException e) {
74+
throw new RuntimeException("Unable to write to requirements.txt", e);
75+
}
76+
}
77+
}
78+
79+
/**
80+
* Overriding this method to change the output location of our tool files.
81+
*/
82+
@Override
83+
public String apiFilename(String templateName, String tag) {
84+
final String originalFilename = super.apiFilename(templateName, tag);
85+
86+
if ("api_tools.mustache".equals(templateName)) {
87+
return originalFilename.replace(apiPackage().replace(".", File.separator),
88+
packageName.replace(".", File.separator) + File.separator + "tools");
89+
}
90+
91+
return originalFilename;
92+
}
93+
94+
}

modules/openapi-generator/src/main/resources/META-INF/services/org.openapitools.codegen.CodegenConfig

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ org.openapitools.codegen.languages.PythonFastAPIServerCodegen
118118
org.openapitools.codegen.languages.PythonFlaskConnexionServerCodegen
119119
org.openapitools.codegen.languages.PythonAiohttpConnexionServerCodegen
120120
org.openapitools.codegen.languages.PythonBluePlanetServerCodegen
121+
org.openapitools.codegen.languages.PythonLangchainToolsClientCodegen
121122
org.openapitools.codegen.languages.RClientCodegen
122123
org.openapitools.codegen.languages.RubyClientCodegen
123124
org.openapitools.codegen.languages.RubyOnRailsServerCodegen

modules/openapi-generator/src/main/resources/python-langchain-tools/README.mustache

Whitespace-only changes.
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
An aggregator for all generated LangChain tools.
4+
5+
This file provides a single list, `all_tools`, that you can import and
6+
provide to your LangChain agent.
7+
"""
8+
from typing import List
9+
from langchain.tools import StructuredTool
10+
11+
from {{packageName}}.api_client import ApiClient
12+
from {{packageName}}.configuration import Configuration
13+
14+
# Import the factory functions from each tool module
15+
{{#apiInfo}}
16+
{{#apis}}
17+
from {{packageName}}.tools.{{classFilename}}_tools import get_{{classVarName}}_tools
18+
{{/apis}}
19+
{{/apiInfo}}
20+
21+
def get_all_tools() -> List[StructuredTool]:
22+
"""
23+
Initializes the API client and aggregates tools from all API modules.
24+
"""
25+
# TODO: You may need to customize this based on your API's authentication needs.
26+
configuration = Configuration(host="{{{serverUrl}}}")
27+
28+
all_tools_list = []
29+
30+
with ApiClient(configuration) as api_client:
31+
{{#apiInfo}}
32+
{{#apis}}
33+
all_tools_list.extend(get_{{classVarName}}_tools(api_client))
34+
{{/apis}}
35+
{{/apiInfo}}
36+
37+
return all_tools_list
38+
39+
# A pre-initialized list of all tools for convenience.
40+
all_tools = get_all_tools()
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# -*- coding: utf-8 -*-
2+
from typing import List
3+
from langchain.tools import StructuredTool
4+
from {{packageName}}.api_client import ApiClient
5+
from {{packageName}}.api.{{classFilename}} import {{classname}}
6+
7+
def get_{{classVarName}}_tools(api_client: ApiClient) -> List[StructuredTool]:
8+
"""A factory function to create and return tools for the {{classname}} API."""
9+
10+
api_instance = {{classname}}(api_client)
11+
tools = []
12+
{{#operations}}
13+
{{#operation}}
14+
15+
{{operationId}}_tool = StructuredTool.from_function(
16+
func=api_instance.{{nickname}},
17+
name="{{operationId}}",
18+
description="{{{summary}}}"
19+
)
20+
tools.append({{operationId}}_tool)
21+
{{/operation}}
22+
{{/operations}}
23+
24+
return tools

modules/openapi-generator/src/main/resources/python-langchain-tools/init.mustache

Whitespace-only changes.
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package org.openapitools.codegen.options;
2+
3+
import org.openapitools.codegen.CodegenConstants;
4+
import org.openapitools.codegen.languages.PythonLangchainToolsClientCodegen;
5+
6+
import com.google.common.collect.ImmutableMap;
7+
8+
import java.util.Map;
9+
10+
public class PythonLangchainToolsClientCodegenOptionsProvider implements OptionsProvider {
11+
public static final String PROJECT_NAME_VALUE = "OpenAPI";
12+
13+
@Override
14+
public String getLanguage() {
15+
return "python-langchain-tools";
16+
}
17+
18+
@Override
19+
public Map<String, String> createOptions() {
20+
ImmutableMap.Builder<String, String> builder = new ImmutableMap.Builder<String, String>();
21+
return builder
22+
.put(PythonLangchainToolsClientCodegen.PROJECT_NAME, PROJECT_NAME_VALUE)
23+
.build();
24+
}
25+
26+
@Override
27+
public boolean isServer() {
28+
return false;
29+
}
30+
}
31+
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package org.openapitools.codegen.python.langchain.tools;
2+
3+
import org.openapitools.codegen.*;
4+
import org.openapitools.codegen.languages.PythonLangchainToolsClientCodegen;
5+
import io.swagger.models.*;
6+
import io.swagger.models.properties.*;
7+
8+
import org.testng.Assert;
9+
import org.testng.annotations.Test;
10+
11+
@SuppressWarnings("static-method")
12+
public class PythonLangchainToolsClientCodegenModelTest {
13+
14+
@Test(description = "convert a simple java model")
15+
public void simpleModelTest() {
16+
final Model model = new ModelImpl()
17+
.description("a sample model")
18+
.property("id", new LongProperty())
19+
.property("name", new StringProperty())
20+
.required("id")
21+
.required("name");
22+
final DefaultCodegen codegen = new PythonLangchainToolsClientCodegen();
23+
24+
// TODO: Complete this test.
25+
Assert.fail("Not implemented.");
26+
}
27+
28+
}
29+
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package org.openapitools.codegen.python.langchain.tools;
2+
3+
import org.openapitools.codegen.AbstractOptionsTest;
4+
import org.openapitools.codegen.CodegenConfig;
5+
import org.openapitools.codegen.languages.PythonLangchainToolsClientCodegen;
6+
import org.openapitools.codegen.options.PythonLangchainToolsClientCodegenOptionsProvider;
7+
8+
import static org.mockito.Mockito.mock;
9+
import static org.mockito.Mockito.verify;
10+
11+
public class PythonLangchainToolsClientCodegenOptionsTest extends AbstractOptionsTest {
12+
private PythonLangchainToolsClientCodegen codegen = mock(PythonLangchainToolsClientCodegen.class, mockSettings);
13+
14+
public PythonLangchainToolsClientCodegenOptionsTest() {
15+
super(new PythonLangchainToolsClientCodegenOptionsProvider());
16+
}
17+
18+
@Override
19+
protected CodegenConfig getCodegenConfig() {
20+
return codegen;
21+
}
22+
23+
@SuppressWarnings("unused")
24+
@Override
25+
protected void verifyOptions() {
26+
// TODO: Complete options using Mockito
27+
// verify(codegen).someMethod(arguments)
28+
}
29+
}
30+

0 commit comments

Comments
 (0)