@@ -237,11 +237,18 @@ def byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCData]:
 # define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
 LLAMA_FILE_MAGIC_GGSN = 0x6767736E
 
+# define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
+LLAMA_FILE_MAGIC_GGSQ = 0x67677371
+
 # define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
 LLAMA_SESSION_MAGIC = LLAMA_FILE_MAGIC_GGSN
 # define LLAMA_SESSION_VERSION 5
 LLAMA_SESSION_VERSION = 5
 
+# define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
+LLAMA_STATE_SEQ_MAGIC = LLAMA_FILE_MAGIC_GGSQ
+# define LLAMA_STATE_SEQ_VERSION 1
+LLAMA_STATE_SEQ_VERSION = 1
 
 # struct llama_model;
 llama_model_p = NewType("llama_model_p", int)
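The new GGSQ magic and version identify single-sequence state files written by the llama_state_seq_* API added further down. A minimal sketch of how a caller might sanity-check such a file, assuming (this is an assumption about the on-disk layout, not something stated in this diff) that the file begins with the little-endian uint32 magic followed by the uint32 version:

import struct

import llama_cpp

def looks_like_seq_state_file(path: str) -> bool:
    # Read the assumed 8-byte header: uint32 magic, then uint32 version.
    with open(path, "rb") as f:
        magic, version = struct.unpack("<II", f.read(8))
    return (
        magic == llama_cpp.LLAMA_STATE_SEQ_MAGIC
        and version == llama_cpp.LLAMA_STATE_SEQ_VERSION
    )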
@@ -1467,6 +1474,7 @@ def llama_kv_cache_clear(ctx: llama_context_p, /):
 
 
 # // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
+# // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
 # // seq_id < 0 : match any sequence
 # // p0 < 0 : [0, p1]
 # // p1 < 0 : [p0, inf)
@@ -1493,6 +1501,9 @@ def llama_kv_cache_seq_rm(
     /,
 ) -> bool:
     """Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
+
+    Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
+
     seq_id < 0 : match any sequence
     p0 < 0 : [0, p1]
     p1 < 0 : [p0, inf)"""
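A short usage sketch for the updated semantics, assuming ctx is an already-initialized llama_context_p:

import llama_cpp

# Remove positions [10, 20) of sequence 0; per the new comment, this may
# return False if only part of a sequence can be removed.
ok = llama_cpp.llama_kv_cache_seq_rm(ctx, 0, 10, 20)

# p0 < 0 and p1 < 0 together match the whole range, i.e. remove all of
# sequence 1; removing a whole sequence never fails.
llama_cpp.llama_kv_cache_seq_rm(ctx, 1, -1, -1)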
@@ -1652,7 +1663,16 @@ def llama_kv_cache_update(ctx: llama_context_p, /):
 
 # Returns the maximum size in bytes of the state (rng, logits, embedding
 # and kv_cache) - will often be smaller after compacting tokens
-# LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx);
+# LLAMA_API size_t llama_state_get_size(const struct llama_context * ctx);
+@ctypes_function("llama_state_get_size", [llama_context_p_ctypes], ctypes.c_size_t)
+def llama_state_get_size(ctx: llama_context_p, /) -> int:
+    """Returns the maximum size in bytes of the state (rng, logits, embedding
+    and kv_cache) - will often be smaller after compacting tokens"""
+    ...
+
+
+# LLAMA_API DEPRECATED(size_t llama_get_state_size(const struct llama_context * ctx),
+#     "use llama_state_get_size instead");
 @ctypes_function("llama_get_state_size", [llama_context_p_ctypes], ctypes.c_size_t)
 def llama_get_state_size(ctx: llama_context_p, /) -> int:
     """Returns the maximum size in bytes of the state (rng, logits, embedding
@@ -1663,9 +1683,30 @@ def llama_get_state_size(ctx: llama_context_p, /) -> int:
 # Copies the state to the specified destination address.
 # Destination needs to have allocated enough memory.
 # Returns the number of bytes copied
-# LLAMA_API size_t llama_copy_state_data(
+# LLAMA_API size_t llama_state_get_data(
 #         struct llama_context * ctx,
 #         uint8_t * dst);
+@ctypes_function(
+    "llama_state_get_data",
+    [
+        llama_context_p_ctypes,
+        ctypes.POINTER(ctypes.c_uint8),
+    ],
+    ctypes.c_size_t,
+)
+def llama_state_get_data(
+    ctx: llama_context_p, dst: CtypesArray[ctypes.c_uint8], /
+) -> int:
+    """Copies the state to the specified destination address.
+    Destination needs to have allocated enough memory.
+    Returns the number of bytes copied"""
+    ...
+
+
+# LLAMA_API DEPRECATED(size_t llama_copy_state_data(
+#             struct llama_context * ctx,
+#             uint8_t * dst),
+#     "use llama_state_get_data instead");
 @ctypes_function(
     "llama_copy_state_data",
     [
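Together, llama_state_get_size and llama_state_get_data replace the older get/copy pair. A hedged sketch of snapshotting the full context state into a Python bytes object, assuming ctx is a live llama_context_p:

import ctypes

import llama_cpp

size = llama_cpp.llama_state_get_size(ctx)   # upper bound, in bytes
buf = (ctypes.c_uint8 * size)()              # destination buffer
n_written = llama_cpp.llama_state_get_data(ctx, buf)
state_bytes = bytes(buf[:n_written])         # may be smaller than size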
@@ -1685,9 +1726,26 @@ def llama_copy_state_data(
 
 # // Set the state reading from the specified address
 # // Returns the number of bytes read
-# LLAMA_API size_t llama_set_state_data(
+# LLAMA_API size_t llama_state_set_data(
 #         struct llama_context * ctx,
 #         const uint8_t * src);
+@ctypes_function(
+    "llama_state_set_data",
+    [llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8)],
+    ctypes.c_size_t,
+)
+def llama_state_set_data(
+    ctx: llama_context_p, src: CtypesArray[ctypes.c_uint8], /
+) -> int:
+    """Set the state reading from the specified address
+    Returns the number of bytes read"""
+    ...
+
+
+# LLAMA_API DEPRECATED(size_t llama_set_state_data(
+#             struct llama_context * ctx,
+#             const uint8_t * src),
+#     "use llama_state_set_data instead");
 @ctypes_function(
     "llama_set_state_data",
     [llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8)],
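And the matching restore path through llama_state_set_data, assuming state_bytes came from the snapshot sketch above and ctx was created with the same model and settings:

import ctypes

import llama_cpp

src = (ctypes.c_uint8 * len(state_bytes)).from_buffer_copy(state_bytes)
n_read = llama_cpp.llama_state_set_data(ctx, src)  # bytes consumed from src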
@@ -1701,12 +1759,40 @@ def llama_set_state_data(
 
 
 # Save/load session file
-# LLAMA_API bool llama_load_session_file(
+# LLAMA_API bool llama_state_load_file(
 #         struct llama_context * ctx,
 #         const char * path_session,
 #         llama_token * tokens_out,
 #         size_t n_token_capacity,
 #         size_t * n_token_count_out);
+@ctypes_function(
+    "llama_state_load_file",
+    [
+        llama_context_p_ctypes,
+        ctypes.c_char_p,
+        llama_token_p,
+        ctypes.c_size_t,
+        ctypes.POINTER(ctypes.c_size_t),
+    ],
+    ctypes.c_bool,
+)
+def llama_state_load_file(
+    ctx: llama_context_p,
+    path_session: bytes,
+    tokens_out: CtypesArray[llama_token],
+    n_token_capacity: Union[ctypes.c_size_t, int],
+    n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
+    /,
+) -> bool: ...
+
+
+# LLAMA_API DEPRECATED(bool llama_load_session_file(
+#             struct llama_context * ctx,
+#             const char * path_session,
+#             llama_token * tokens_out,
+#             size_t n_token_capacity,
+#             size_t * n_token_count_out),
+#     "use llama_state_load_file instead");
 @ctypes_function(
     "llama_load_session_file",
     [
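A sketch of restoring a session file through the renamed entry point; n_ctx is assumed to be the context size the session was saved with, and ctx an initialized llama_context_p:

import ctypes

import llama_cpp

tokens = (llama_cpp.llama_token * n_ctx)()
n_count = ctypes.c_size_t(0)
ok = llama_cpp.llama_state_load_file(
    ctx, b"session.bin", tokens, n_ctx, ctypes.byref(n_count)
)
if ok:
    # The tokens the session was saved with, needed to resume generation.
    session_tokens = list(tokens[: n_count.value])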
@@ -1728,11 +1814,36 @@ def llama_load_session_file(
 ) -> int: ...
 
 
-# LLAMA_API bool llama_save_session_file(
+# LLAMA_API bool llama_state_save_file(
 #         struct llama_context * ctx,
 #         const char * path_session,
 #         const llama_token * tokens,
 #         size_t n_token_count);
+@ctypes_function(
+    "llama_state_save_file",
+    [
+        llama_context_p_ctypes,
+        ctypes.c_char_p,
+        llama_token_p,
+        ctypes.c_size_t,
+    ],
+    ctypes.c_bool,
+)
+def llama_state_save_file(
+    ctx: llama_context_p,
+    path_session: bytes,
+    tokens: CtypesArray[llama_token],
+    n_token_count: Union[ctypes.c_size_t, int],
+    /,
+) -> bool: ...
+
+
+# LLAMA_API DEPRECATED(bool llama_save_session_file(
+#             struct llama_context * ctx,
+#             const char * path_session,
+#             const llama_token * tokens,
+#             size_t n_token_count),
+#     "use llama_state_save_file instead");
 @ctypes_function(
     "llama_save_session_file",
     [
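The corresponding save, assuming tokens is a (llama_token * n_tokens) ctypes array holding the prompt that produced the current state:

import llama_cpp

ok = llama_cpp.llama_state_save_file(ctx, b"session.bin", tokens, n_tokens)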
@@ -1752,6 +1863,116 @@ def llama_save_session_file(
 ) -> int: ...
 
 
+# // Get the exact size needed to copy the KV cache of a single sequence
+# LLAMA_API size_t llama_state_seq_get_size(
+#         struct llama_context * ctx,
+#         llama_seq_id seq_id);
+@ctypes_function(
+    "llama_state_seq_get_size",
+    [llama_context_p_ctypes, llama_seq_id],
+    ctypes.c_size_t,
+)
+def llama_state_seq_get_size(ctx: llama_context_p, seq_id: llama_seq_id, /) -> int:
+    """Get the exact size needed to copy the KV cache of a single sequence"""
+    ...
+
+
+# // Copy the KV cache of a single sequence into the specified buffer
+# LLAMA_API size_t llama_state_seq_get_data(
+#         struct llama_context * ctx,
+#         uint8_t * dst,
+#         llama_seq_id seq_id);
+@ctypes_function(
+    "llama_state_seq_get_data",
+    [llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8), llama_seq_id],
+    ctypes.c_size_t,
+)
+def llama_state_seq_get_data(
+    ctx: llama_context_p, dst: CtypesArray[ctypes.c_uint8], seq_id: llama_seq_id, /
+) -> int:
+    """Copy the KV cache of a single sequence into the specified buffer"""
+    ...
+
+
+# // Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
+# // Returns:
+# //  - Positive: Ok
+# //  - Zero: Failed to load
+# LLAMA_API size_t llama_state_seq_set_data(
+#         struct llama_context * ctx,
+#         const uint8_t * src,
+#         llama_seq_id dest_seq_id);
+@ctypes_function(
+    "llama_state_seq_set_data",
+    [llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8), llama_seq_id],
+    ctypes.c_size_t,
+)
+def llama_state_seq_set_data(
+    ctx: llama_context_p, src: CtypesArray[ctypes.c_uint8], dest_seq_id: llama_seq_id, /
+) -> int:
+    """Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence"""
+    ...
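These three calls make it possible to duplicate one sequence's KV cache into another, e.g. to fork a conversation. A hedged sketch, assuming ctx is a live llama_context_p with data in sequence 0:

import ctypes

import llama_cpp

size = llama_cpp.llama_state_seq_get_size(ctx, 0)
buf = (ctypes.c_uint8 * size)()
llama_cpp.llama_state_seq_get_data(ctx, buf, 0)

# Replay the copied cells into sequence 1; per the header comment above,
# a zero return means the load failed.
if llama_cpp.llama_state_seq_set_data(ctx, buf, 1) == 0:
    raise RuntimeError("failed to load sequence state")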
+
+
+# LLAMA_API size_t llama_state_seq_save_file(
+#         struct llama_context * ctx,
+#         const char * filepath,
+#         llama_seq_id seq_id,
+#         const llama_token * tokens,
+#         size_t n_token_count);
+@ctypes_function(
+    "llama_state_seq_save_file",
+    [
+        llama_context_p_ctypes,
+        ctypes.c_char_p,
+        llama_seq_id,
+        llama_token_p,
+        ctypes.c_size_t,
+    ],
+    ctypes.c_size_t,
+)
+def llama_state_seq_save_file(
+    ctx: llama_context_p,
+    filepath: bytes,
+    seq_id: llama_seq_id,
+    tokens: CtypesArray[llama_token],
+    n_token_count: Union[ctypes.c_size_t, int],
+    /,
+) -> int:
+    ...
+
+
+# LLAMA_API size_t llama_state_seq_load_file(
+#         struct llama_context * ctx,
+#         const char * filepath,
+#         llama_seq_id dest_seq_id,
+#         llama_token * tokens_out,
+#         size_t n_token_capacity,
+#         size_t * n_token_count_out);
+@ctypes_function(
+    "llama_state_seq_load_file",
+    [
+        llama_context_p_ctypes,
+        ctypes.c_char_p,
+        llama_seq_id,
+        llama_token_p,
+        ctypes.c_size_t,
+        ctypes.POINTER(ctypes.c_size_t),
+    ],
+    ctypes.c_size_t,
+)
+def llama_state_seq_load_file(
+    ctx: llama_context_p,
+    filepath: bytes,
+    dest_seq_id: llama_seq_id,
+    tokens_out: CtypesArray[llama_token],
+    n_token_capacity: Union[ctypes.c_size_t, int],
+    n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
+    /,
+) -> int:
+    ...
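File-based variants of the same idea. The sketch assumes sequence 0 holds n_tokens tokens whose ids are in a ctypes array tokens, and that the loading context uses a compatible model; the assumption that a zero return from the load signals failure follows the set_data convention above:

import ctypes

import llama_cpp

llama_cpp.llama_state_seq_save_file(ctx, b"seq0.bin", 0, tokens, n_tokens)

# Later: restore the file into sequence 2 of a compatible context.
tokens_out = (llama_cpp.llama_token * n_ctx)()
n_out = ctypes.c_size_t(0)
n_read = llama_cpp.llama_state_seq_load_file(
    ctx, b"seq0.bin", 2, tokens_out, n_ctx, ctypes.byref(n_out)
)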
+
+
 # //
 # // Decoding
 # //
@@ -1930,8 +2151,9 @@ def llama_get_logits(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]:
     ...
 
 
-# // Logits for the ith token. Equivalent to:
+# // Logits for the ith token. For positive indices, equivalent to:
 # // llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab
+# // Negative indices can be used to access logits in reverse order, -1 is the last logit.
 # // returns NULL for invalid ids.
 # LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
 @ctypes_function(
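With this change, index -1 addresses the most recent output. A sketch, assuming ctx has just run llama_decode with logits requested for the last token:

import llama_cpp

logits = llama_cpp.llama_get_logits_ith(ctx, -1)  # last token's logits
n_vocab = llama_cpp.llama_n_vocab(llama_cpp.llama_get_model(ctx))
scores = [logits[i] for i in range(n_vocab)]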
@@ -1963,8 +2185,9 @@ def llama_get_embeddings(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]
     ...
 
 
-# // Get the embeddings for the ith token. Equivalent to:
+# // Get the embeddings for the ith token. For positive indices, equivalent to:
 # // llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
+# // Negative indices can be used to access embeddings in reverse order, -1 is the last embedding.
 # // shape: [n_embd] (1-dimensional)
 # // returns NULL for invalid ids.
 # LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
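The same convention applies to embeddings; this sketch assumes the context was created with embeddings enabled and has just decoded a batch:

import llama_cpp

emb = llama_cpp.llama_get_embeddings_ith(ctx, -1)  # last token's embedding
n_embd = llama_cpp.llama_n_embd(llama_cpp.llama_get_model(ctx))
vector = [emb[i] for i in range(n_embd)]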