Skip to content

Commit 053b278

Browse files
authored
feat(langchain): reinstate OpenAIModerationChain (#9242)
1 parent f88cde5 commit 053b278

File tree

2 files changed

+165
-0
lines changed

2 files changed

+165
-0
lines changed

libs/langchain-classic/src/chains/index.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,8 @@ export {
9999
createOpenAPIChain,
100100
convertOpenAPISpecToOpenAIFunctions,
101101
} from "./openai_functions/openapi.js";
102+
103+
export {
104+
type OpenAIModerationChainInput,
105+
OpenAIModerationChain,
106+
} from "./openai_moderation.js";
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
import { type ClientOptions, OpenAIClient } from "@langchain/openai";
2+
import { ChainValues } from "@langchain/core/utils/types";
3+
import {
4+
AsyncCaller,
5+
AsyncCallerParams,
6+
} from "@langchain/core/utils/async_caller";
7+
import { getEnvironmentVariable } from "@langchain/core/utils/env";
8+
import { BaseChain, ChainInputs } from "./base.js";
9+
10+
/**
11+
* Interface for the input parameters of the OpenAIModerationChain class.
12+
*/
13+
export interface OpenAIModerationChainInput
14+
extends ChainInputs,
15+
AsyncCallerParams {
16+
apiKey?: string;
17+
/** @deprecated Use "apiKey" instead. */
18+
openAIApiKey?: string;
19+
openAIOrganization?: string;
20+
throwError?: boolean;
21+
configuration?: ClientOptions;
22+
}
23+
24+
/**
25+
* Class representing a chain for moderating text using the OpenAI
26+
* Moderation API. It extends the BaseChain class and implements the
27+
* OpenAIModerationChainInput interface.
28+
* @example
29+
* ```typescript
30+
* const moderation = new OpenAIModerationChain({ throwError: true });
31+
*
32+
* const badString = "Bad naughty words from user";
33+
*
34+
* try {
35+
* const { output: moderatedContent, results } = await moderation.call({
36+
* input: badString,
37+
* });
38+
*
39+
* if (results[0].category_scores["harassment/threatening"] > 0.01) {
40+
* throw new Error("Harassment detected!");
41+
* }
42+
*
43+
* const model = new OpenAI({ temperature: 0 });
44+
* const promptTemplate = "Hello, how are you today {person}?";
45+
* const prompt = new PromptTemplate({
46+
* template: promptTemplate,
47+
* inputVariables: ["person"],
48+
* });
49+
* const chain = new LLMChain({ llm: model, prompt });
50+
* const response = await chain.call({ person: moderatedContent });
51+
* console.log({ response });
52+
* } catch (error) {
53+
* console.error("Naughty words detected!");
54+
* }
55+
* ```
56+
*/
57+
export class OpenAIModerationChain
58+
extends BaseChain
59+
implements OpenAIModerationChainInput
60+
{
61+
static lc_name() {
62+
return "OpenAIModerationChain";
63+
}
64+
65+
get lc_secrets(): { [key: string]: string } | undefined {
66+
return {
67+
openAIApiKey: "OPENAI_API_KEY",
68+
};
69+
}
70+
71+
inputKey = "input";
72+
73+
outputKey = "output";
74+
75+
openAIApiKey?: string;
76+
77+
openAIOrganization?: string;
78+
79+
clientConfig: ClientOptions;
80+
81+
client: OpenAIClient;
82+
83+
throwError: boolean;
84+
85+
caller: AsyncCaller;
86+
87+
constructor(fields?: OpenAIModerationChainInput) {
88+
super(fields);
89+
this.throwError = fields?.throwError ?? false;
90+
this.openAIApiKey =
91+
fields?.apiKey ??
92+
fields?.openAIApiKey ??
93+
getEnvironmentVariable("OPENAI_API_KEY");
94+
95+
if (!this.openAIApiKey) {
96+
throw new Error("OpenAI API key not found");
97+
}
98+
99+
this.openAIOrganization = fields?.openAIOrganization;
100+
101+
this.clientConfig = {
102+
...fields?.configuration,
103+
apiKey: this.openAIApiKey,
104+
organization: this.openAIOrganization,
105+
};
106+
107+
this.client = new OpenAIClient(this.clientConfig);
108+
109+
this.caller = new AsyncCaller(fields ?? {});
110+
}
111+
112+
_moderate(text: string, results: OpenAIClient.Moderation): string {
113+
if (results.flagged) {
114+
const errorStr = "Text was found that violates OpenAI's content policy.";
115+
if (this.throwError) {
116+
throw new Error(errorStr);
117+
} else {
118+
return errorStr;
119+
}
120+
}
121+
return text;
122+
}
123+
124+
async _call(values: ChainValues): Promise<ChainValues> {
125+
const text = values[this.inputKey];
126+
const moderationRequest: OpenAIClient.ModerationCreateParams = {
127+
input: text,
128+
};
129+
let mod;
130+
try {
131+
mod = await this.caller.call(() =>
132+
this.client.moderations.create(moderationRequest)
133+
);
134+
} catch (error) {
135+
// eslint-disable-next-line no-instanceof/no-instanceof
136+
if (error instanceof Error) {
137+
throw error;
138+
} else {
139+
throw new Error(error as string);
140+
}
141+
}
142+
const output = this._moderate(text, mod.results[0]);
143+
return {
144+
[this.outputKey]: output,
145+
results: mod.results,
146+
};
147+
}
148+
149+
_chainType() {
150+
return "moderation_chain";
151+
}
152+
153+
get inputKeys(): string[] {
154+
return [this.inputKey];
155+
}
156+
157+
get outputKeys(): string[] {
158+
return [this.outputKey];
159+
}
160+
}

0 commit comments

Comments
 (0)