@@ -1694,6 +1694,14 @@ def llama_model_is_recurrent(model: llama_model_p, /) -> bool:
16941694 ...
16951695
16961696
1697+ # // Returns true if the model is hybrid (like Jamba, Granite, etc.)
1698+ # LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model);
1699+ @ctypes_function ("llama_model_is_hybrid" , [llama_model_p_ctypes ], ctypes .c_bool )
1700+ def llama_model_is_hybrid (model : llama_model_p , / ) -> bool :
1701+ """Returns true if the model is hybrid (like Jamba, Granite, etc.)"""
1702+ ...
1703+
1704+
16971705# // Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
16981706# LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model);
16991707@ctypes_function ("llama_model_is_diffusion" , [llama_model_p_ctypes ], ctypes .c_bool )
@@ -2539,6 +2547,92 @@ def llama_state_seq_load_file(
25392547 ...
25402548
25412549
2550+ # // for backwards-compat
2551+ LLAMA_STATE_SEQ_FLAGS_SWA_ONLY = 1
2552+
2553+ # // work only with partial states, such as SWA KV cache or recurrent cache (e.g. Mamba)
2554+ LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY = 1
2555+
2556+ llama_state_seq_flags = ctypes .c_uint32
2557+
2558+ # LLAMA_API size_t llama_state_seq_get_size_ext(
2559+ # struct llama_context * ctx,
2560+ # llama_seq_id seq_id,
2561+ # llama_state_seq_flags flags);
2562+ @ctypes_function (
2563+ "llama_state_seq_get_size_ext" ,
2564+ [
2565+ llama_context_p_ctypes ,
2566+ llama_seq_id ,
2567+ llama_state_seq_flags ,
2568+ ],
2569+ ctypes .c_size_t ,
2570+ )
2571+ def llama_state_seq_get_size_ext (
2572+ ctx : llama_context_p ,
2573+ seq_id : llama_seq_id ,
2574+ flags : llama_state_seq_flags ,
2575+ / ,
2576+ ) -> int :
2577+ ...
2578+
2579+
2580+ # LLAMA_API size_t llama_state_seq_get_data_ext(
2581+ # struct llama_context * ctx,
2582+ # uint8_t * dst,
2583+ # size_t size,
2584+ # llama_seq_id seq_id,
2585+ # llama_state_seq_flags flags);
2586+ @ctypes_function (
2587+ "llama_state_seq_get_data_ext" ,
2588+ [
2589+ llama_context_p_ctypes ,
2590+ ctypes .POINTER (ctypes .c_uint8 ),
2591+ ctypes .c_size_t ,
2592+ llama_seq_id ,
2593+ llama_state_seq_flags ,
2594+ ],
2595+ ctypes .c_size_t ,
2596+ )
2597+ def llama_state_seq_get_data_ext (
2598+ ctx : llama_context_p ,
2599+ dst : ctypes .POINTER (ctypes .c_uint8 ),
2600+ size : Union [int , ctypes .c_size_t ],
2601+ seq_id : llama_seq_id ,
2602+ flags : llama_state_seq_flags ,
2603+ / ,
2604+ ) -> int :
2605+ ...
2606+
2607+
2608+ # LLAMA_API size_t llama_state_seq_set_data_ext(
2609+ # struct llama_context * ctx,
2610+ # const uint8_t * src,
2611+ # size_t size,
2612+ # llama_seq_id dest_seq_id,
2613+ # llama_state_seq_flags flags);
2614+ @ctypes_function (
2615+ "llama_state_seq_set_data_ext" ,
2616+ [
2617+ llama_context_p_ctypes ,
2618+ ctypes .POINTER (ctypes .c_uint8 ),
2619+ ctypes .c_size_t ,
2620+ llama_seq_id ,
2621+ llama_state_seq_flags ,
2622+ ],
2623+ ctypes .c_size_t ,
2624+ )
2625+ def llama_state_seq_set_data_ext (
2626+ ctx : llama_context_p ,
2627+ src : ctypes .POINTER (ctypes .c_uint8 ),
2628+ size : Union [int , ctypes .c_size_t ],
2629+ dest_seq_id : llama_seq_id ,
2630+ flags : llama_state_seq_flags ,
2631+ / ,
2632+ ) -> int :
2633+ ...
2634+
2635+
25422636# //
25432637# // Decoding
25442638# //
0 commit comments