Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lisette/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@
'lisette.core.Chat.print_hist': ('core.html#chat.print_hist', 'lisette/core.py'),
'lisette.core._add_cache_control': ('core.html#_add_cache_control', 'lisette/core.py'),
'lisette.core._alite_call_func': ('core.html#_alite_call_func', 'lisette/core.py'),
'lisette.core._apply_cache_idxs': ('core.html#_apply_cache_idxs', 'lisette/core.py'),
'lisette.core._bytes2content': ('core.html#_bytes2content', 'lisette/core.py'),
'lisette.core._extract_tool': ('core.html#_extract_tool', 'lisette/core.py'),
'lisette.core._has_cache': ('core.html#_has_cache', 'lisette/core.py'),
'lisette.core._has_search': ('core.html#_has_search', 'lisette/core.py'),
'lisette.core._lite_call_func': ('core.html#_lite_call_func', 'lisette/core.py'),
'lisette.core._mk_content': ('core.html#_mk_content', 'lisette/core.py'),
'lisette.core._mk_prefill': ('core.html#_mk_prefill', 'lisette/core.py'),
'lisette.core._skip_tools_cache': ('core.html#_skip_tools_cache', 'lisette/core.py'),
'lisette.core._trunc_str': ('core.html#_trunc_str', 'lisette/core.py'),
'lisette.core.adisplay_stream': ('core.html#adisplay_stream', 'lisette/core.py'),
'lisette.core.astream_with_complete': ('core.html#astream_with_complete', 'lisette/core.py'),
Expand Down
15 changes: 6 additions & 9 deletions lisette/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,12 @@ def fmt2hist(outp:str)->list:
return hist

# %% ../nbs/00_core.ipynb
def _skip_tools_cache(msgs, cache_idxs):
"Skip tool use blocks and tool results in and shift cache indices"
res = []
for idx in cache_idxs:
try:
while msgs[idx].get('tool_calls', []) or msgs[idx]['role'] == 'tool': idx -= 1
def _apply_cache_idxs(msgs, cache_idxs=[-1], ttl=None):
'Add cache control to idxs after filtering tools'
ms = L(msgs).filter(lambda m: not (m.get('tool_calls', []) or m['role'] == 'tool'))
for i in cache_idxs:
try: _add_cache_control(ms[i], ttl)
except IndexError: continue
res.append(idx)
return res

# %% ../nbs/00_core.ipynb
def mk_msgs(
Expand All @@ -175,7 +172,7 @@ def mk_msgs(
for m in msgs:
res.append(msg:=remove_cache_ckpts(mk_msg(m, role=role)))
role = 'assistant' if msg['role'] in ('user','function', 'tool') else 'user'
if cache: L(_skip_tools_cache(res, cache_idxs)).map(lambda idx: _add_cache_control(res[idx], ttl))
if cache: _apply_cache_idxs(res, cache_idxs, ttl)
return res

# %% ../nbs/00_core.ipynb
Expand Down
128 changes: 18 additions & 110 deletions nbs/00_core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,7 @@
{
"data": {
"text/plain": [
"PromptTokensDetailsWrapper(audio_tokens=None, cached_tokens=4144, text_tokens=None, image_tokens=None, cache_creation_tokens=0, cache_creation_token_details=CacheCreationTokenDetails(ephemeral_5m_input_tokens=0, ephemeral_1h_input_tokens=0))"
"PromptTokensDetailsWrapper(audio_tokens=None, cached_tokens=2070, text_tokens=None, image_tokens=None)"
]
},
"execution_count": null,
Expand Down Expand Up @@ -1115,35 +1115,22 @@
{
"cell_type": "code",
"execution_count": null,
"id": "d2a8e351",
"id": "02cb84da",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"def _skip_tools_cache(msgs, cache_idxs):\n",
" \"Skip tool use blocks and tool results in and shift cache indices\"\n",
" res = []\n",
" for idx in cache_idxs:\n",
" try: \n",
" while msgs[idx].get('tool_calls', []) or msgs[idx]['role'] == 'tool': idx -= 1\n",
" except IndexError: continue\n",
" res.append(idx)\n",
" return res"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "632e25d5",
"metadata": {},
"outputs": [],
"source": [
"test_eq(_skip_tools_cache([], [-2]), [])"
"def _apply_cache_idxs(msgs, cache_idxs=[-1], ttl=None):\n",
" 'Add cache control to idxs after filtering tools'\n",
" ms = L(msgs).filter(lambda m: not (m.get('tool_calls', []) or m['role'] == 'tool'))\n",
" for i in cache_idxs:\n",
" try: _add_cache_control(ms[i], ttl)\n",
" except IndexError: continue"
]
},
{
"cell_type": "markdown",
"id": "b8708dfa",
"id": "4ad6deec",
"metadata": {},
"source": [
"Now lets make it easy to provide entire conversations:"
Expand Down Expand Up @@ -1171,21 +1158,10 @@
" for m in msgs:\n",
" res.append(msg:=remove_cache_ckpts(mk_msg(m, role=role)))\n",
" role = 'assistant' if msg['role'] in ('user','function', 'tool') else 'user'\n",
" if cache: L(_skip_tools_cache(res, cache_idxs)).map(lambda idx: _add_cache_control(res[idx], ttl))\n",
" if cache: _apply_cache_idxs(res, cache_idxs, ttl)\n",
" return res"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b54618d8",
"metadata": {},
"outputs": [],
"source": [
"msgs = mk_msgs(['Hey, call some functions!',fmt_outp,'How are you?','I am great!'])\n",
"test_eq(L(_skip_tools_cache(msgs, [0,1,2,-4,-3,-2,-1])), [0,0,0,-10,-3,-2,-1])"
]
},
{
"cell_type": "markdown",
"id": "6b4624e2",
Expand Down Expand Up @@ -1883,19 +1859,19 @@
"## Using the Power Rule: d/dx(xⁿ) = n·xⁿ⁻¹\n",
"\n",
"**Term by term:**\n",
"- d/dx("
"- d/dx(x³) = 3x²\n",
"- d/dx(2x²) = 4x\n",
"- d/dx(-5x) = -5\n",
"- d/dx(1) = 0\n",
"\n",
"##"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"x³) = 3x²\n",
"- d/dx(2x²) = 4x\n",
"- d/dx(-5x) = -5\n",
"- d/dx(1) = 0\n",
"\n",
"## Answer:\n",
" Answer:\n",
"**f'(x) = 3x² + 4x - 5**"
]
}
Expand Down Expand Up @@ -5102,75 +5078,7 @@
"**Batch 1: Calculate the independent additions**\n",
"Let me start by calculating the two addition operations that don't depend on each other:\n",
"- 10 + 5 (for the numerator)\n",
"- 2 + 1 (for the denominator)\n",
"\n",
"<details class='tool-usage-details'>\n",
"\n",
"```json\n",
"{\n",
" \"id\": \"toolu_017kxrJ18BiKbYNV5vuJ4Lpf\",\n",
" \"call\": {\n",
" \"function\": \"simple_add\",\n",
" \"arguments\": {\n",
" \"a\": \"10\",\n",
" \"b\": \"5\"\n",
" }\n",
" },\n",
" \"result\": \"15\"\n",
"}\n",
"```\n",
"\n",
"</details>\n",
"\n",
"\n",
"\n",
"<details class='tool-usage-details'>\n",
"\n",
"```json\n",
"{\n",
" \"id\": \"toolu_01XnKDfoAU6uqAL1V8CFwUwS\",\n",
" \"call\": {\n",
" \"function\": \"simple_add\",\n",
" \"arguments\": {\n",
" \"a\": \"2\",\n",
" \"b\": \"1\"\n",
" }\n",
" },\n",
" \"result\": \"3\"\n",
"}\n",
"```\n",
"\n",
"</details>\n",
"\n",
"**After Batch 1:** We have:\n",
"- 10 + 5 = 15\n",
"- 2 + 1 = 3\n",
"- So our expression is now: (15 * 3) / 3\n",
"\n",
"**Batch 2: Calculate the multiplication**\n",
"Now I need to multiply 15 * 3 before I can do the final division:\n",
"\n",
"<details class='tool-usage-details'>\n",
"\n",
"```json\n",
"{\n",
" \"id\": \"toolu_01P6iHCianQyB3sfxXquqKik\",\n",
" \"call\": {\n",
" \"function\": \"multiply\",\n",
" \"arguments\": {\n",
" \"a\": \"15\",\n",
" \"b\": \"3\"\n",
" }\n",
" },\n",
" \"result\": \"45\"\n",
"}\n",
"```\n",
"\n",
"</details>\n",
"\n",
"\n",
"\n",
"**"
"- 2 + 1 (for the denominator)"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
Expand Down