Skip to content

Commit a144498

Browse files
committed
feat(Jinja): support rejectattr
1 parent 9cb9c0d commit a144498

File tree

3 files changed

+81
-0
lines changed

3 files changed

+81
-0
lines changed

packages/jinja/src/runtime.ts

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,40 @@ export class Interpreter {
647647

648648
return new ArrayValue(filtered);
649649
}
650+
case "rejectattr": {
651+
if (operand.value.some((x) => !(x instanceof ObjectValue))) {
652+
throw new Error("`rejectattr` can only be applied to array of objects");
653+
}
654+
if (filter.args.some((x) => x.type !== "StringLiteral")) {
655+
throw new Error("arguments of `rejectattr` must be strings");
656+
}
657+
658+
const [attr, testName, value] = filter.args.map((x) => this.evaluate(x, environment)) as StringValue[];
659+
660+
let testFunction: (...x: AnyRuntimeValue[]) => boolean;
661+
if (testName) {
662+
// Get the test function from the environment
663+
const test = environment.tests.get(testName.value);
664+
if (!test) {
665+
throw new Error(`Unknown test: ${testName.value}`);
666+
}
667+
testFunction = test;
668+
} else {
669+
// Default to truthiness of first argument
670+
testFunction = (...x: AnyRuntimeValue[]) => x[0].__bool__().value;
671+
}
672+
673+
// Filter the array using the test function
674+
const filtered = (operand.value as ObjectValue[]).filter((item) => {
675+
const a = item.value.get(attr.value);
676+
if (a) {
677+
return !testFunction(a, value);
678+
}
679+
return true;
680+
});
681+
682+
return new ArrayValue(filtered);
683+
}
650684
case "map": {
651685
// Accumulate kwargs
652686
const [, kwargs] = this.evaluateArguments(filter.args, environment);

packages/jinja/test/e2e.test.js

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -630,6 +630,15 @@ const TEST_CUSTOM_TEMPLATES = Object.freeze({
630630
},
631631
target: `<|begin_of_text|>You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: <tools> {"type": "function", "function": {"name": get_stock_fundamentals", "description": "get_stock_fundamentals(symbol: str) -> dict - Get fundamental data for a given stock symbol using yfinance API.\n\n Args:\n symbol(str): The stock symbol.\n Returns:\n A dictionary containing fundamental data.\n\nKeys:\n - 'symbol': The stock symbol.\n - 'company_name': The long name of the company.\n - 'sector': The sector to which the company belongs.\n - 'industry': The industry to which the company belongs.\n - 'market_cap': The market capitalization of the company.\n - 'pe_ratio': The forward price-to-earnings ratio.\n - 'pb_ratio': The price-to-book ratio.\n - 'dividend_yield': The dividend yield.\n - 'eps': The trailing earnings per share.\n - 'beta': The beta value of the stock.\n - '52_week_high': The 52-week high price of the stock.\n - '52_week_low': The 52-week low price of the stock.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string", "description": "The stock symbol."}}, "required": ["symbol"]}} </tools>Use the following pydantic model json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"}\nFor each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:\n<tool_call>\n{"arguments": <args-dict>, "name": <function-name>}\n</tool_call><|im_end|><|im_start|>user\nFetch the stock fundamentals data for Tesla (TSLA)<|im_end|>\n<|im_start|>assistant\n`,
632632
},
633+
"mistralai/Mistral-Nemo-Instruct-2407": {
634+
chat_template: `{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}\n\n{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %}\n {%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %}\n {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message["role"] == "user" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- "[AVAILABLE_TOOLS][" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{"type": "function", "function": {' }}\n {%- for key, val in tool.items() if key != "return" %}\n {%- if val is string %}\n {{- '"' + key + '": "' + val + '"' }}\n {%- else %}\n {{- '"' + key + '": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- "}}" }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" }}\n {%- endif %}\n {%- endfor %}\n {{- "[/AVAILABLE_TOOLS]" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- "[INST]" + system_message + "\\n\\n" + message["content"] + "[/INST]" }}\n {%- else %}\n {{- "[INST]" + message["content"] + "[/INST]" }}\n {%- endif %}\n {%- elif message["role"] == "tool_calls" or message.tool_calls is defined %}\n {%- if message.tool_calls is defined %}\n {%- set tool_calls = message.tool_calls %}\n {%- else %}\n {%- set tool_calls = message.content %}\n {%- endif %}\n {{- "[TOOL_CALLS][" }}\n {%- for tool_call in tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- ', "id": "' + tool_call.id + '"}' }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message["role"] == "assistant" %}\n {{- message["content"] + eos_token}}\n {%- elif message["role"] == "tool_results" or message["role"] == "tool" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS]{"content": ' + content|string + ", " }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- '"call_id": "' + message.tool_call_id + '"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}\n`,
635+
data: {
636+
messages: EXAMPLE_CHAT,
637+
bos_token: "<s>",
638+
eos_token: "</s>"
639+
},
640+
target: `<s>[INST]Hello, how are you?[/INST]I'm doing great. How can I help you today?</s>[INST]I'd like to show off how chat templating works![/INST]`,
641+
},
633642
});
634643

635644
describe("End-to-end tests", () => {

packages/jinja/test/templates.test.js

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ const TEST_STRINGS = {
8787
FILTER_OPERATOR_8: `{{ obj | tojson(indent=2) }}`,
8888
FILTER_OPERATOR_9: `{{ data | map(attribute='val') | list | tojson }}`,
8989
FILTER_OPERATOR_10: `|{{ " 1 \n 2 \n 3 \n\n " | indent }}|{{ " 1 \n 2 \n 3 \n\n " | indent(2) }}|{{ " 1 \n 2 \n 3 \n\n " | indent(first=True) }}|{{ " 1 \n 2 \n 3 \n\n " | indent(blank=True) }}|{{ " 1 \n 2 \n 3 \n\n " | indent(4, first=True) }}|`,
90+
FILTER_OPERATOR_11: `{{ items | rejectattr('key') | length }}`,
91+
FILTER_OPERATOR_12: `{{ messages | rejectattr('role', 'equalto', 'system') | length }}`,
9092

9193
// Logical operators between non-Booleans
9294
BOOLEAN_NUMERICAL: `|{{ 1 and 2 }}|{{ 1 and 0 }}|{{ 0 and 1 }}|{{ 0 and 0 }}|{{ 1 or 2 }}|{{ 1 or 0 }}|{{ 0 or 1 }}|{{ 0 or 0 }}|{{ not 1 }}|{{ not 0 }}|`,
@@ -1624,6 +1626,34 @@ const TEST_PARSED = {
16241626
{ value: "}}", type: "CloseExpression" },
16251627
{ value: "|", type: "Text" },
16261628
],
1629+
FILTER_OPERATOR_11: [
1630+
{ value: "{{", type: "OpenExpression" },
1631+
{ value: "items", type: "Identifier" },
1632+
{ value: "|", type: "Pipe" },
1633+
{ value: "rejectattr", type: "Identifier" },
1634+
{ value: "(", type: "OpenParen" },
1635+
{ value: "key", type: "StringLiteral" },
1636+
{ value: ")", type: "CloseParen" },
1637+
{ value: "|", type: "Pipe" },
1638+
{ value: "length", type: "Identifier" },
1639+
{ value: "}}", type: "CloseExpression" },
1640+
],
1641+
FILTER_OPERATOR_12: [
1642+
{ value: "{{", type: "OpenExpression" },
1643+
{ value: "messages", type: "Identifier" },
1644+
{ value: "|", type: "Pipe" },
1645+
{ value: "rejectattr", type: "Identifier" },
1646+
{ value: "(", type: "OpenParen" },
1647+
{ value: "role", type: "StringLiteral" },
1648+
{ value: ",", type: "Comma" },
1649+
{ value: "equalto", type: "StringLiteral" },
1650+
{ value: ",", type: "Comma" },
1651+
{ value: "system", type: "StringLiteral" },
1652+
{ value: ")", type: "CloseParen" },
1653+
{ value: "|", type: "Pipe" },
1654+
{ value: "length", type: "Identifier" },
1655+
{ value: "}}", type: "CloseExpression" },
1656+
],
16271657

16281658
// Logical operators between non-Booleans
16291659
BOOLEAN_NUMERICAL: [
@@ -2909,6 +2939,12 @@ const TEST_CONTEXT = {
29092939
data: [{ val: 1 }, { val: 2 }, { val: 3 }],
29102940
},
29112941
FILTER_OPERATOR_10: {},
2942+
FILTER_OPERATOR_11: {
2943+
items: [{ key: "a" }, { key: 0 }, { key: 1 }, {}, { key: false }],
2944+
},
2945+
FILTER_OPERATOR_12: {
2946+
messages: [{ role: "system" }, { role: "user" }, { role: "assistant" }],
2947+
},
29122948

29132949
// Logical operators between non-Booleans
29142950
BOOLEAN_NUMERICAL: {},
@@ -3073,6 +3109,8 @@ const EXPECTED_OUTPUTS = {
30733109
FILTER_OPERATOR_8: `{\n "a": [\n 1,\n 2,\n 3\n ],\n "b": 1,\n "c": {\n "d": 2,\n "e": {\n "f": 3,\n "g": {\n "h": 4,\n "i": [\n 1,\n 2,\n 3\n ]\n }\n }\n }\n}`,
30743110
FILTER_OPERATOR_9: `[1, 2, 3]`,
30753111
FILTER_OPERATOR_10: `| 1 \n 2 \n 3 \n\n | 1 \n 2 \n 3 \n\n | 1 \n 2 \n 3 \n\n | 1 \n 2 \n 3 \n \n | 1 \n 2 \n 3 \n\n |`,
3112+
FILTER_OPERATOR_11: `3`,
3113+
FILTER_OPERATOR_12: `2`,
30763114

30773115
// Logical operators between non-Booleans
30783116
BOOLEAN_NUMERICAL: `|2|0|0|0|1|1|1|0|false|true|`,

0 commit comments

Comments
 (0)