Skip to content

Commit ca88bdd

Browse files
authored
feat: support strong typed tool (#13)
1 parent e1e464f commit ca88bdd

File tree

18 files changed

+2397
-2083
lines changed

18 files changed

+2397
-2083
lines changed

src/core/athena.ts

Lines changed: 90 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,72 @@ import logger from "../utils/logger.js";
55

66
export type Dict<T> = { [key: string]: T };
77

8-
export interface IAthenaArgument {
9-
type: "string" | "number" | "boolean" | "object" | "array";
8+
type IAthenaArgumentPrimitive = {
9+
type: "string" | "number" | "boolean";
1010
desc: string;
1111
required: boolean;
12-
of?: Dict<IAthenaArgument> | IAthenaArgument;
13-
}
12+
};
13+
14+
export type IAthenaArgument =
15+
| IAthenaArgumentPrimitive
16+
| {
17+
type: "object" | "array";
18+
desc: string;
19+
required: boolean;
20+
of?: Dict<IAthenaArgument> | IAthenaArgument;
21+
};
22+
type IAthenaArgumentInstance<T extends IAthenaArgument> =
23+
T extends IAthenaArgumentPrimitive
24+
? T["type"] extends "string"
25+
? T["required"] extends true
26+
? string
27+
: string | undefined
28+
: T["type"] extends "number"
29+
? T["required"] extends true
30+
? number
31+
: number | undefined
32+
: T["type"] extends "boolean"
33+
? T["required"] extends true
34+
? boolean
35+
: boolean | undefined
36+
: never
37+
: T extends { of: Dict<IAthenaArgument> }
38+
? T["required"] extends true
39+
? { [K in keyof T["of"]]: IAthenaArgumentInstance<T["of"][K]> }
40+
:
41+
| { [K in keyof T["of"]]: IAthenaArgumentInstance<T["of"][K]> }
42+
| undefined
43+
: T extends { of: IAthenaArgument }
44+
? T["required"] extends true
45+
? IAthenaArgumentInstance<T["of"]>[]
46+
: IAthenaArgumentInstance<T["of"]>[] | undefined
47+
: T extends { type: "object" }
48+
? T["required"] extends true
49+
? { [K in keyof T["of"]]: any }
50+
: { [K in keyof T["of"]]: any } | undefined
51+
: T extends { type: "array" }
52+
? T["required"] extends true
53+
? any[]
54+
: (any | undefined)[]
55+
: never;
1456

15-
export interface IAthenaTool {
57+
export interface IAthenaTool<
58+
Args extends Dict<IAthenaArgument> = Dict<IAthenaArgument>,
59+
RetArgs extends Dict<IAthenaArgument> = Dict<IAthenaArgument>,
60+
> {
1661
name: string;
1762
desc: string;
18-
args: Dict<IAthenaArgument>;
19-
retvals: Dict<IAthenaArgument>;
20-
fn: (args: Dict<any>) => Promise<Dict<any>>;
63+
args: Args;
64+
retvals: RetArgs;
65+
fn: (args: {
66+
[K in keyof Args]: Args[K] extends IAthenaArgument
67+
? IAthenaArgumentInstance<Args[K]>
68+
: never;
69+
}) => Promise<{
70+
[K in keyof RetArgs]: RetArgs[K] extends IAthenaArgument
71+
? IAthenaArgumentInstance<RetArgs[K]>
72+
: never;
73+
}>;
2174
explain_args?: (args: Dict<any>) => IAthenaExplanation;
2275
explain_retvals?: (args: Dict<any>, retvals: Dict<any>) => IAthenaExplanation;
2376
}
@@ -38,15 +91,15 @@ export class Athena extends EventEmitter {
3891
config: Dict<any>;
3992
states: Dict<Dict<any>>;
4093
plugins: Dict<PluginBase>;
41-
tools: Dict<IAthenaTool>;
94+
tools: Map<string, IAthenaTool<any, any>>;
4295
events: Dict<IAthenaEvent>;
4396

4497
constructor(config: Dict<any>, states: Dict<Dict<any>>) {
4598
super();
4699
this.config = config;
47100
this.states = states;
48101
this.plugins = {};
49-
this.tools = {};
102+
this.tools = new Map();
50103
this.events = {};
51104
}
52105

@@ -103,19 +156,39 @@ export class Athena extends EventEmitter {
103156
logger.warn(`Plugin ${name} is unloaded`);
104157
}
105158

106-
registerTool(tool: IAthenaTool) {
159+
registerTool<
160+
Args extends Dict<IAthenaArgument>,
161+
RetArgs extends Dict<IAthenaArgument>,
162+
Tool extends IAthenaTool<Args, RetArgs>,
163+
>(
164+
config: {
165+
name: string;
166+
desc: string;
167+
args: Args;
168+
retvals: RetArgs;
169+
},
170+
toolImpl: {
171+
fn: Tool["fn"];
172+
explain_args?: Tool["explain_args"];
173+
explain_retvals?: Tool["explain_retvals"];
174+
},
175+
) {
176+
const tool = {
177+
...config,
178+
...toolImpl,
179+
};
107180
if (tool.name in this.tools) {
108181
throw new Error(`Tool ${tool.name} already registered`);
109182
}
110-
this.tools[tool.name] = tool;
183+
this.tools.set(tool.name, tool as unknown as IAthenaTool<any, any>);
111184
logger.warn(`Tool ${tool.name} is registered`);
112185
}
113186

114187
deregisterTool(name: string) {
115188
if (!(name in this.tools)) {
116189
throw new Error(`Tool ${name} not registered`);
117190
}
118-
delete this.tools[name];
191+
this.tools.delete(name);
119192
logger.warn(`Tool ${name} is deregistered`);
120193
}
121194

@@ -159,7 +232,10 @@ export class Athena extends EventEmitter {
159232
if (!(name in this.tools)) {
160233
throw new Error(`Tool ${name} not registered`);
161234
}
162-
const tool = this.tools[name];
235+
const tool = this.tools.get(name);
236+
if (!tool) {
237+
throw new Error(`Tool ${name} not found`);
238+
}
163239
if (tool.explain_args) {
164240
this.emitPrivateEvent("athena/tool-call", tool.explain_args(args));
165241
}

src/plugins/amadeus/init.ts

Lines changed: 101 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -11,108 +11,112 @@ export default class AmadeusPlugin extends PluginBase {
1111
clientId: this.config.client_id,
1212
clientSecret: this.config.client_secret,
1313
});
14-
athena.registerTool({
15-
name: "amadeus/flight-offers-search",
16-
desc: "Return list of Flight Offers based on searching criteria.",
17-
args: {
18-
originLocationCode: {
19-
type: "string",
20-
desc: "city/airport IATA code from which the traveler will depart, e.g. BOS for Boston\nExample : SYD",
21-
required: true,
14+
athena.registerTool(
15+
{
16+
name: "amadeus/flight-offers-search",
17+
desc: "Return list of Flight Offers based on searching criteria.",
18+
args: {
19+
originLocationCode: {
20+
type: "string",
21+
desc: "city/airport IATA code from which the traveler will depart, e.g. BOS for Boston\nExample : SYD",
22+
required: true,
23+
},
24+
destinationLocationCode: {
25+
type: "string",
26+
desc: "city/airport IATA code to which the traveler is going, e.g. PAR for Paris\nExample : BKK",
27+
required: true,
28+
},
29+
departureDate: {
30+
type: "string",
31+
desc: "the date on which the traveler will depart from the origin to go to the destination. Dates are specified in the ISO 8601 YYYY-MM-DD format, e.g. 2017-12-25\nExample : 2023-05-02",
32+
required: true,
33+
},
34+
returnDate: {
35+
type: "string",
36+
desc: "the date on which the traveler will depart from the origin to go to the destination. Dates are specified in the ISO 8601 YYYY-MM-DD format, e.g. 2017-12-25\nExample : 2023-05-02",
37+
required: false,
38+
},
39+
adults: {
40+
type: "number",
41+
desc: "the number of adult travelers (age 12 or older on date of departure). The total number of seated travelers (adult and children) can not exceed 9.\nDefault value : 1",
42+
required: true,
43+
},
44+
children: {
45+
type: "number",
46+
desc: "the number of child travelers (older than age 2 and younger than age 12 on date of departure) who will each have their own separate seat. If specified, this number should be greater than or equal to 0\nThe total number of seated travelers (adult and children) can not exceed 9.",
47+
required: false,
48+
},
49+
infants: {
50+
type: "number",
51+
desc: "the number of infant travelers (whose age is less or equal to 2 on date of departure). Infants travel on the lap of an adult traveler, and thus the number of infants must not exceed the number of adults. If specified, this number should be greater than or equal to 0",
52+
required: false,
53+
},
54+
travelClass: {
55+
type: "string",
56+
desc: "most of the flight time should be spent in a cabin of this quality or higher. The accepted travel class is economy, premium economy, business or first class. If no travel class is specified, the search considers any travel class\nAvailable values : ECONOMY, PREMIUM_ECONOMY, BUSINESS, FIRST",
57+
required: false,
58+
},
59+
includedAirlineCodes: {
60+
type: "string",
61+
desc: "This option ensures that the system will only consider these airlines. This can not be cumulated with parameter excludedAirlineCodes.\nAirlines are specified as IATA airline codes and are comma-separated, e.g. 6X,7X,8X",
62+
required: false,
63+
},
64+
excludedAirlineCodes: {
65+
type: "string",
66+
desc: "This option ensures that the system will ignore these airlines. This can not be cumulated with parameter includedAirlineCodes.\nAirlines are specified as IATA airline codes and are comma-separated, e.g. 6X,7X,8X",
67+
required: false,
68+
},
69+
nonStop: {
70+
type: "boolean",
71+
desc: "if set to true, the search will find only flights going from the origin to the destination with no stop in between\nDefault value : false",
72+
required: false,
73+
},
74+
currencyCode: {
75+
type: "string",
76+
desc: "the preferred currency for the flight offers. Currency is specified in the ISO 4217 format, e.g. EUR for Euro",
77+
required: false,
78+
},
79+
maxPrice: {
80+
type: "number",
81+
desc: "maximum price per traveler. By default, no limit is applied. If specified, the value should be a positive number with no decimals",
82+
required: false,
83+
},
84+
max: {
85+
type: "number",
86+
desc: "maximum number of flight offers to return. If specified, the value should be greater than or equal to 1\nDefault value : 250",
87+
required: false,
88+
},
2289
},
23-
destinationLocationCode: {
24-
type: "string",
25-
desc: "city/airport IATA code to which the traveler is going, e.g. PAR for Paris\nExample : BKK",
26-
required: true,
27-
},
28-
departureDate: {
29-
type: "string",
30-
desc: "the date on which the traveler will depart from the origin to go to the destination. Dates are specified in the ISO 8601 YYYY-MM-DD format, e.g. 2017-12-25\nExample : 2023-05-02",
31-
required: true,
32-
},
33-
returnDate: {
34-
type: "string",
35-
desc: "the date on which the traveler will depart from the origin to go to the destination. Dates are specified in the ISO 8601 YYYY-MM-DD format, e.g. 2017-12-25\nExample : 2023-05-02",
36-
required: false,
37-
},
38-
adults: {
39-
type: "number",
40-
desc: "the number of adult travelers (age 12 or older on date of departure). The total number of seated travelers (adult and children) can not exceed 9.\nDefault value : 1",
41-
required: true,
42-
},
43-
children: {
44-
type: "number",
45-
desc: "the number of child travelers (older than age 2 and younger than age 12 on date of departure) who will each have their own separate seat. If specified, this number should be greater than or equal to 0\nThe total number of seated travelers (adult and children) can not exceed 9.",
46-
required: false,
47-
},
48-
infants: {
49-
type: "number",
50-
desc: "the number of infant travelers (whose age is less or equal to 2 on date of departure). Infants travel on the lap of an adult traveler, and thus the number of infants must not exceed the number of adults. If specified, this number should be greater than or equal to 0",
51-
required: false,
52-
},
53-
travelClass: {
54-
type: "string",
55-
desc: "most of the flight time should be spent in a cabin of this quality or higher. The accepted travel class is economy, premium economy, business or first class. If no travel class is specified, the search considers any travel class\nAvailable values : ECONOMY, PREMIUM_ECONOMY, BUSINESS, FIRST",
56-
required: false,
57-
},
58-
includedAirlineCodes: {
59-
type: "string",
60-
desc: "This option ensures that the system will only consider these airlines. This can not be cumulated with parameter excludedAirlineCodes.\nAirlines are specified as IATA airline codes and are comma-separated, e.g. 6X,7X,8X",
61-
required: false,
62-
},
63-
excludedAirlineCodes: {
64-
type: "string",
65-
desc: "This option ensures that the system will ignore these airlines. This can not be cumulated with parameter includedAirlineCodes.\nAirlines are specified as IATA airline codes and are comma-separated, e.g. 6X,7X,8X",
66-
required: false,
67-
},
68-
nonStop: {
69-
type: "boolean",
70-
desc: "if set to true, the search will find only flights going from the origin to the destination with no stop in between\nDefault value : false",
71-
required: false,
72-
},
73-
currencyCode: {
74-
type: "string",
75-
desc: "the preferred currency for the flight offers. Currency is specified in the ISO 4217 format, e.g. EUR for Euro",
76-
required: false,
77-
},
78-
maxPrice: {
79-
type: "number",
80-
desc: "maximum price per traveler. By default, no limit is applied. If specified, the value should be a positive number with no decimals",
81-
required: false,
82-
},
83-
max: {
84-
type: "number",
85-
desc: "maximum number of flight offers to return. If specified, the value should be greater than or equal to 1\nDefault value : 250",
86-
required: false,
90+
retvals: {
91+
data: {
92+
desc: "The flight offers",
93+
type: "object",
94+
required: true,
95+
},
8796
},
8897
},
89-
retvals: {
90-
data: {
91-
desc: "The flight offers",
92-
type: "object",
93-
required: true,
98+
{
99+
fn: async (args) => {
100+
const response = await this.amadeus.shopping.flightOffersSearch.get({
101+
originLocationCode: args.originLocationCode,
102+
destinationLocationCode: args.destinationLocationCode,
103+
departureDate: args.departureDate,
104+
returnDate: args.returnDate,
105+
adults: args.adults,
106+
children: args.children,
107+
infants: args.infants,
108+
travelClass: args.travelClass,
109+
includedAirlineCodes: args.includedAirlineCodes,
110+
excludedAirlineCodes: args.excludedAirlineCodes,
111+
nonStop: args.nonStop,
112+
currencyCode: args.currencyCode,
113+
maxPrice: args.maxPrice,
114+
max: args.max,
115+
});
116+
return { data: response.data };
94117
},
95118
},
96-
fn: async (args: Dict<any>) => {
97-
const response = await this.amadeus.shopping.flightOffersSearch.get({
98-
originLocationCode: args.originLocationCode,
99-
destinationLocationCode: args.destinationLocationCode,
100-
departureDate: args.departureDate,
101-
returnDate: args.returnDate,
102-
adults: args.adults,
103-
children: args.children,
104-
infants: args.infants,
105-
travelClass: args.travelClass,
106-
includedAirlineCodes: args.includedAirlineCodes,
107-
excludedAirlineCodes: args.excludedAirlineCodes,
108-
nonStop: args.nonStop,
109-
currencyCode: args.currencyCode,
110-
maxPrice: args.maxPrice,
111-
max: args.max,
112-
});
113-
return { data: response.data };
114-
},
115-
});
119+
);
116120
}
117121

118122
async unload(athena: Athena) {

0 commit comments

Comments
 (0)