Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 90 additions & 14 deletions src/core/athena.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,72 @@ import logger from "../utils/logger.js";

export type Dict<T> = { [key: string]: T };

export interface IAthenaArgument {
type: "string" | "number" | "boolean" | "object" | "array";
type IAthenaArgumentPrimitive = {
type: "string" | "number" | "boolean";
desc: string;
required: boolean;
of?: Dict<IAthenaArgument> | IAthenaArgument;
}
};

export type IAthenaArgument =
| IAthenaArgumentPrimitive
| {
type: "object" | "array";
desc: string;
required: boolean;
of?: Dict<IAthenaArgument> | IAthenaArgument;
};
type IAthenaArgumentInstance<T extends IAthenaArgument> =
T extends IAthenaArgumentPrimitive
? T["type"] extends "string"
? T["required"] extends true
? string
: string | undefined
: T["type"] extends "number"
? T["required"] extends true
? number
: number | undefined
: T["type"] extends "boolean"
? T["required"] extends true
? boolean
: boolean | undefined
: never
: T extends { of: Dict<IAthenaArgument> }
? T["required"] extends true
? { [K in keyof T["of"]]: IAthenaArgumentInstance<T["of"][K]> }
:
| { [K in keyof T["of"]]: IAthenaArgumentInstance<T["of"][K]> }
| undefined
: T extends { of: IAthenaArgument }
? T["required"] extends true
? IAthenaArgumentInstance<T["of"]>[]
: IAthenaArgumentInstance<T["of"]>[] | undefined
: T extends { type: "object" }
? T["required"] extends true
? { [K in keyof T["of"]]: any }
: { [K in keyof T["of"]]: any } | undefined
: T extends { type: "array" }
? T["required"] extends true
? any[]
: (any | undefined)[]
: never;

export interface IAthenaTool {
export interface IAthenaTool<
Args extends Dict<IAthenaArgument> = Dict<IAthenaArgument>,
RetArgs extends Dict<IAthenaArgument> = Dict<IAthenaArgument>,
> {
name: string;
desc: string;
args: Dict<IAthenaArgument>;
retvals: Dict<IAthenaArgument>;
fn: (args: Dict<any>) => Promise<Dict<any>>;
args: Args;
retvals: RetArgs;
fn: (args: {
[K in keyof Args]: Args[K] extends IAthenaArgument
? IAthenaArgumentInstance<Args[K]>
: never;
}) => Promise<{
[K in keyof RetArgs]: RetArgs[K] extends IAthenaArgument
? IAthenaArgumentInstance<RetArgs[K]>
: never;
}>;
explain_args?: (args: Dict<any>) => IAthenaExplanation;
explain_retvals?: (args: Dict<any>, retvals: Dict<any>) => IAthenaExplanation;
}
Expand All @@ -38,15 +91,15 @@ export class Athena extends EventEmitter {
config: Dict<any>;
states: Dict<Dict<any>>;
plugins: Dict<PluginBase>;
tools: Dict<IAthenaTool>;
tools: Map<string, IAthenaTool<any, any>>;
events: Dict<IAthenaEvent>;

constructor(config: Dict<any>, states: Dict<Dict<any>>) {
super();
this.config = config;
this.states = states;
this.plugins = {};
this.tools = {};
this.tools = new Map();
this.events = {};
}

Expand Down Expand Up @@ -103,19 +156,39 @@ export class Athena extends EventEmitter {
logger.warn(`Plugin ${name} is unloaded`);
}

registerTool(tool: IAthenaTool) {
registerTool<
Args extends Dict<IAthenaArgument>,
RetArgs extends Dict<IAthenaArgument>,
Tool extends IAthenaTool<Args, RetArgs>,
>(
config: {
name: string;
desc: string;
args: Args;
retvals: RetArgs;
},
toolImpl: {
fn: Tool["fn"];
explain_args?: Tool["explain_args"];
explain_retvals?: Tool["explain_retvals"];
},
) {
const tool = {
...config,
...toolImpl,
};
if (tool.name in this.tools) {
throw new Error(`Tool ${tool.name} already registered`);
}
this.tools[tool.name] = tool;
this.tools.set(tool.name, tool as unknown as IAthenaTool<any, any>);
logger.warn(`Tool ${tool.name} is registered`);
}

deregisterTool(name: string) {
if (!(name in this.tools)) {
throw new Error(`Tool ${name} not registered`);
}
delete this.tools[name];
this.tools.delete(name);
logger.warn(`Tool ${name} is deregistered`);
}

Expand Down Expand Up @@ -159,7 +232,10 @@ export class Athena extends EventEmitter {
if (!(name in this.tools)) {
throw new Error(`Tool ${name} not registered`);
}
const tool = this.tools[name];
const tool = this.tools.get(name);
if (!tool) {
throw new Error(`Tool ${name} not found`);
}
if (tool.explain_args) {
this.emitPrivateEvent("athena/tool-call", tool.explain_args(args));
}
Expand Down
198 changes: 101 additions & 97 deletions src/plugins/amadeus/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,108 +11,112 @@ export default class AmadeusPlugin extends PluginBase {
clientId: this.config.client_id,
clientSecret: this.config.client_secret,
});
athena.registerTool({
name: "amadeus/flight-offers-search",
desc: "Return list of Flight Offers based on searching criteria.",
args: {
originLocationCode: {
type: "string",
desc: "city/airport IATA code from which the traveler will depart, e.g. BOS for Boston\nExample : SYD",
required: true,
athena.registerTool(
{
name: "amadeus/flight-offers-search",
desc: "Return list of Flight Offers based on searching criteria.",
args: {
originLocationCode: {
type: "string",
desc: "city/airport IATA code from which the traveler will depart, e.g. BOS for Boston\nExample : SYD",
required: true,
},
destinationLocationCode: {
type: "string",
desc: "city/airport IATA code to which the traveler is going, e.g. PAR for Paris\nExample : BKK",
required: true,
},
departureDate: {
type: "string",
desc: "the date on which the traveler will depart from the origin to go to the destination. Dates are specified in the ISO 8601 YYYY-MM-DD format, e.g. 2017-12-25\nExample : 2023-05-02",
required: true,
},
returnDate: {
type: "string",
desc: "the date on which the traveler will depart from the origin to go to the destination. Dates are specified in the ISO 8601 YYYY-MM-DD format, e.g. 2017-12-25\nExample : 2023-05-02",
required: false,
},
adults: {
type: "number",
desc: "the number of adult travelers (age 12 or older on date of departure). The total number of seated travelers (adult and children) can not exceed 9.\nDefault value : 1",
required: true,
},
children: {
type: "number",
desc: "the number of child travelers (older than age 2 and younger than age 12 on date of departure) who will each have their own separate seat. If specified, this number should be greater than or equal to 0\nThe total number of seated travelers (adult and children) can not exceed 9.",
required: false,
},
infants: {
type: "number",
desc: "the number of infant travelers (whose age is less or equal to 2 on date of departure). Infants travel on the lap of an adult traveler, and thus the number of infants must not exceed the number of adults. If specified, this number should be greater than or equal to 0",
required: false,
},
travelClass: {
type: "string",
desc: "most of the flight time should be spent in a cabin of this quality or higher. The accepted travel class is economy, premium economy, business or first class. If no travel class is specified, the search considers any travel class\nAvailable values : ECONOMY, PREMIUM_ECONOMY, BUSINESS, FIRST",
required: false,
},
includedAirlineCodes: {
type: "string",
desc: "This option ensures that the system will only consider these airlines. This can not be cumulated with parameter excludedAirlineCodes.\nAirlines are specified as IATA airline codes and are comma-separated, e.g. 6X,7X,8X",
required: false,
},
excludedAirlineCodes: {
type: "string",
desc: "This option ensures that the system will ignore these airlines. This can not be cumulated with parameter includedAirlineCodes.\nAirlines are specified as IATA airline codes and are comma-separated, e.g. 6X,7X,8X",
required: false,
},
nonStop: {
type: "boolean",
desc: "if set to true, the search will find only flights going from the origin to the destination with no stop in between\nDefault value : false",
required: false,
},
currencyCode: {
type: "string",
desc: "the preferred currency for the flight offers. Currency is specified in the ISO 4217 format, e.g. EUR for Euro",
required: false,
},
maxPrice: {
type: "number",
desc: "maximum price per traveler. By default, no limit is applied. If specified, the value should be a positive number with no decimals",
required: false,
},
max: {
type: "number",
desc: "maximum number of flight offers to return. If specified, the value should be greater than or equal to 1\nDefault value : 250",
required: false,
},
},
destinationLocationCode: {
type: "string",
desc: "city/airport IATA code to which the traveler is going, e.g. PAR for Paris\nExample : BKK",
required: true,
},
departureDate: {
type: "string",
desc: "the date on which the traveler will depart from the origin to go to the destination. Dates are specified in the ISO 8601 YYYY-MM-DD format, e.g. 2017-12-25\nExample : 2023-05-02",
required: true,
},
returnDate: {
type: "string",
desc: "the date on which the traveler will depart from the origin to go to the destination. Dates are specified in the ISO 8601 YYYY-MM-DD format, e.g. 2017-12-25\nExample : 2023-05-02",
required: false,
},
adults: {
type: "number",
desc: "the number of adult travelers (age 12 or older on date of departure). The total number of seated travelers (adult and children) can not exceed 9.\nDefault value : 1",
required: true,
},
children: {
type: "number",
desc: "the number of child travelers (older than age 2 and younger than age 12 on date of departure) who will each have their own separate seat. If specified, this number should be greater than or equal to 0\nThe total number of seated travelers (adult and children) can not exceed 9.",
required: false,
},
infants: {
type: "number",
desc: "the number of infant travelers (whose age is less or equal to 2 on date of departure). Infants travel on the lap of an adult traveler, and thus the number of infants must not exceed the number of adults. If specified, this number should be greater than or equal to 0",
required: false,
},
travelClass: {
type: "string",
desc: "most of the flight time should be spent in a cabin of this quality or higher. The accepted travel class is economy, premium economy, business or first class. If no travel class is specified, the search considers any travel class\nAvailable values : ECONOMY, PREMIUM_ECONOMY, BUSINESS, FIRST",
required: false,
},
includedAirlineCodes: {
type: "string",
desc: "This option ensures that the system will only consider these airlines. This can not be cumulated with parameter excludedAirlineCodes.\nAirlines are specified as IATA airline codes and are comma-separated, e.g. 6X,7X,8X",
required: false,
},
excludedAirlineCodes: {
type: "string",
desc: "This option ensures that the system will ignore these airlines. This can not be cumulated with parameter includedAirlineCodes.\nAirlines are specified as IATA airline codes and are comma-separated, e.g. 6X,7X,8X",
required: false,
},
nonStop: {
type: "boolean",
desc: "if set to true, the search will find only flights going from the origin to the destination with no stop in between\nDefault value : false",
required: false,
},
currencyCode: {
type: "string",
desc: "the preferred currency for the flight offers. Currency is specified in the ISO 4217 format, e.g. EUR for Euro",
required: false,
},
maxPrice: {
type: "number",
desc: "maximum price per traveler. By default, no limit is applied. If specified, the value should be a positive number with no decimals",
required: false,
},
max: {
type: "number",
desc: "maximum number of flight offers to return. If specified, the value should be greater than or equal to 1\nDefault value : 250",
required: false,
retvals: {
data: {
desc: "The flight offers",
type: "object",
required: true,
},
},
},
retvals: {
data: {
desc: "The flight offers",
type: "object",
required: true,
{
fn: async (args) => {
const response = await this.amadeus.shopping.flightOffersSearch.get({
originLocationCode: args.originLocationCode,
destinationLocationCode: args.destinationLocationCode,
departureDate: args.departureDate,
returnDate: args.returnDate,
adults: args.adults,
children: args.children,
infants: args.infants,
travelClass: args.travelClass,
includedAirlineCodes: args.includedAirlineCodes,
excludedAirlineCodes: args.excludedAirlineCodes,
nonStop: args.nonStop,
currencyCode: args.currencyCode,
maxPrice: args.maxPrice,
max: args.max,
});
return { data: response.data };
},
},
fn: async (args: Dict<any>) => {
const response = await this.amadeus.shopping.flightOffersSearch.get({
originLocationCode: args.originLocationCode,
destinationLocationCode: args.destinationLocationCode,
departureDate: args.departureDate,
returnDate: args.returnDate,
adults: args.adults,
children: args.children,
infants: args.infants,
travelClass: args.travelClass,
includedAirlineCodes: args.includedAirlineCodes,
excludedAirlineCodes: args.excludedAirlineCodes,
nonStop: args.nonStop,
currencyCode: args.currencyCode,
maxPrice: args.maxPrice,
max: args.max,
});
return { data: response.data };
},
});
);
}

async unload(athena: Athena) {
Expand Down
Loading