@@ -4,6 +4,8 @@
 import sys
 import uuid
 import time
+import json
+import fnmatch
 import multiprocessing
 from typing import (
     List,
@@ -16,6 +18,7 @@
     Callable,
 )
 from collections import deque
+from pathlib import Path
 
 import ctypes
 
@@ -29,10 +32,7 @@
     LlamaDiskCache,  # type: ignore
     LlamaRAMCache,  # type: ignore
 )
-from .llama_tokenizer import (
-    BaseLlamaTokenizer,
-    LlamaTokenizer
-)
+from .llama_tokenizer import BaseLlamaTokenizer, LlamaTokenizer
 import llama_cpp.llama_cpp as llama_cpp
 import llama_cpp.llama_chat_format as llama_chat_format
 
@@ -50,9 +50,7 @@
     _LlamaSamplingContext,  # type: ignore
 )
 from ._logger import set_verbose
-from ._utils import (
-    suppress_stdout_stderr
-)
+from ._utils import suppress_stdout_stderr
 
 
 class Llama:
@@ -189,7 +187,11 @@ def __init__(
             Llama.__backend_initialized = True
 
         if isinstance(numa, bool):
-            self.numa = llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE if numa else llama_cpp.GGML_NUMA_STRATEGY_DISABLED
+            self.numa = (
+                llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE
+                if numa
+                else llama_cpp.GGML_NUMA_STRATEGY_DISABLED
+            )
         else:
             self.numa = numa
 
@@ -246,17 +248,17 @@ def __init__(
                 else:
                     raise ValueError(f"Unknown value type for {k}: {v}")
 
-            self._kv_overrides_array[
-                -1
-            ].key = b"\0"  # ensure sentinel element is zeroed
+            self._kv_overrides_array[-1].key = (
+                b"\0"  # ensure sentinel element is zeroed
+            )
             self.model_params.kv_overrides = self._kv_overrides_array
 
         self.n_batch = min(n_ctx, n_batch)  # ???
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
         self.n_threads_batch = n_threads_batch or max(
             multiprocessing.cpu_count() // 2, 1
         )
-        
+
         # Context Params
         self.context_params = llama_cpp.llama_context_default_params()
         self.context_params.seed = seed
@@ -289,7 +291,9 @@ def __init__(
         )
         self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0
         self.context_params.mul_mat_q = mul_mat_q
-        self.context_params.logits_all = logits_all if draft_model is None else True  # Must be set to True for speculative decoding
+        self.context_params.logits_all = (
+            logits_all if draft_model is None else True
+        )  # Must be set to True for speculative decoding
         self.context_params.embedding = embedding
         self.context_params.offload_kqv = offload_kqv
 
@@ -379,8 +383,14 @@ def __init__(
         if self.verbose:
             print(f"Model metadata: {self.metadata}", file=sys.stderr)
 
-        if self.chat_format is None and self.chat_handler is None and "tokenizer.chat_template" in self.metadata:
-            chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(self.metadata)
+        if (
+            self.chat_format is None
+            and self.chat_handler is None
+            and "tokenizer.chat_template" in self.metadata
+        ):
+            chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(
+                self.metadata
+            )
 
             if chat_format is not None:
                 self.chat_format = chat_format
@@ -406,9 +416,7 @@ def __init__(
                     print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
 
                 self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
-                    template=template,
-                    eos_token=eos_token,
-                    bos_token=bos_token
+                    template=template, eos_token=eos_token, bos_token=bos_token
                 ).to_chat_handler()
 
         if self.chat_format is None and self.chat_handler is None:
@@ -459,7 +467,9 @@ def tokenize(
         """
         return self.tokenizer_.tokenize(text, add_bos, special)
 
-    def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
+    def detokenize(
+        self, tokens: List[int], prev_tokens: Optional[List[int]] = None
+    ) -> bytes:
         """Detokenize a list of tokens.
 
         Args:
@@ -565,7 +575,7 @@ def sample(
             logits[:] = (
                 logits_processor(self._input_ids, logits)
                 if idx is None
-                else logits_processor(self._input_ids[:idx + 1], logits)
+                else logits_processor(self._input_ids[: idx + 1], logits)
             )
 
         sampling_params = _LlamaSamplingParams(
@@ -707,7 +717,9 @@ def generate(
 
             if self.draft_model is not None:
                 self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens
-                draft_tokens = self.draft_model(self.input_ids[:self.n_tokens + len(tokens)])
+                draft_tokens = self.draft_model(
+                    self.input_ids[: self.n_tokens + len(tokens)]
+                )
                 tokens.extend(
                     draft_tokens.astype(int)[
                         : self._n_ctx - self.n_tokens - len(tokens)
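# Interface sketch (an assumption drawn from the call above, not code from this
# diff; the class name is hypothetical): generate() only needs draft_model to
# be a callable that maps the current input_ids array to an array of proposed
# token ids, which it then trims to fit the remaining context.
import numpy as np

class RepeatLastDraftModel:
    """Toy draft model: proposes the most recent token a few more times."""

    def __init__(self, num_pred_tokens: int = 4):
        self.num_pred_tokens = num_pred_tokens

    def __call__(self, input_ids: np.ndarray) -> np.ndarray:
        # Return num_pred_tokens copies of the last token as the draft.
        return np.repeat(input_ids[-1:], self.num_pred_tokens)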
@@ -792,6 +804,7 @@ def embed(
 
         # decode and fetch embeddings
         data: List[List[float]] = []
+
         def decode_batch(n_seq: int):
             assert self._ctx.ctx is not None
             llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
@@ -800,9 +813,9 @@ def decode_batch(n_seq: int):
 
             # store embeddings
             for i in range(n_seq):
-                embedding: List[float] = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
-                    :n_embd
-                ]
+                embedding: List[float] = llama_cpp.llama_get_embeddings_ith(
+                    self._ctx.ctx, i
+                )[:n_embd]
                 if normalize:
                     norm = float(np.linalg.norm(embedding))
                     embedding = [v / norm for v in embedding]
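# Why normalize: dividing by the L2 norm yields unit-length embeddings, so
# cosine similarity reduces to a plain dot product. A minimal numeric sketch
# (the two vectors below are made-up stand-ins for normalized embeddings):
import numpy as np

a = np.array([0.6, 0.8])      # ||a|| == 1.0
b = np.array([0.8, 0.6])      # ||b|| == 1.0
cosine = float(np.dot(a, b))  # no division by norms needed for unit vectors
print(cosine)                 # 0.96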
@@ -1669,12 +1682,13 @@ def create_chat_completion_openai_v1(
         """
         try:
             from openai.types.chat import ChatCompletion, ChatCompletionChunk
-            stream = kwargs.get("stream", False) # type: ignore
+
+            stream = kwargs.get("stream", False)  # type: ignore
             assert isinstance(stream, bool)
             if stream:
-                return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs)) # type: ignore
+                return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs))  # type: ignore
             else:
-                return ChatCompletion(**self.create_chat_completion(*args, **kwargs)) # type: ignore
+                return ChatCompletion(**self.create_chat_completion(*args, **kwargs))  # type: ignore
         except ImportError:
             raise ImportError(
                 "To use create_chat_completion_openai_v1, you must install the openai package."
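# Usage sketch for the method above (the model path and messages are
# placeholder values): it mirrors the OpenAI client's return types, giving a
# ChatCompletion object, or a generator of ChatCompletionChunk when
# stream=True is passed through.
from llama_cpp import Llama

llm = Llama(model_path="./model.gguf")  # hypothetical local GGUF file

completion = llm.create_chat_completion_openai_v1(
    messages=[{"role": "user", "content": "Hello!"}],
)
print(completion.choices[0].message.content)

for chunk in llm.create_chat_completion_openai_v1(
    messages=[{"role": "user", "content": "Hello!"}],
    stream=True,
):
    print(chunk.choices[0].delta.content or "", end="")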
@@ -1866,7 +1880,88 @@ def longest_token_prefix(a: Sequence[int], b: Sequence[int]):
                 break
         return longest_prefix
 
+    @classmethod
+    def from_pretrained(
+        cls,
+        repo_id: str,
+        filename: Optional[str],
+        local_dir: Optional[Union[str, os.PathLike[str]]] = ".",
+        local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
+        **kwargs: Any,
+    ) -> "Llama":
+        """Create a Llama model from a pretrained model name or path.
+        This method requires the huggingface-hub package.
+        You can install it with `pip install huggingface-hub`.
+
+        Args:
+            repo_id: The model repo id.
+            filename: A filename or glob pattern to match the model file in the repo.
+            local_dir: The local directory to save the model to.
+            local_dir_use_symlinks: Whether to use symlinks when downloading the model.
+            **kwargs: Additional keyword arguments to pass to the Llama constructor.
+
+        Returns:
+            A Llama model."""
+        try:
+            from huggingface_hub import hf_hub_download, HfFileSystem
+            from huggingface_hub.utils import validate_repo_id
+        except ImportError:
+            raise ImportError(
+                "Llama.from_pretrained requires the huggingface-hub package. "
+                "You can install it with `pip install huggingface-hub`."
+            )
+
+        validate_repo_id(repo_id)
+
+        hffs = HfFileSystem()
+
+        files = [
+            file["name"] if isinstance(file, dict) else file
+            for file in hffs.ls(repo_id)
+        ]
+
+        # split each file into repo_id, subfolder, filename
+        file_list: List[str] = []
+        for file in files:
+            rel_path = Path(file).relative_to(repo_id)
+            file_list.append(str(rel_path))
 
+        matching_files = [file for file in file_list if fnmatch.fnmatch(file, filename)]  # type: ignore
+
+        if len(matching_files) == 0:
+            raise ValueError(
+                f"No file found in {repo_id} matching {filename}\n\n"
+                f"Available Files:\n{json.dumps(file_list)}"
+            )
+
+        if len(matching_files) > 1:
+            raise ValueError(
+                f"Multiple files found in {repo_id} matching {filename}\n\n"
+                f"Available Files:\n{json.dumps(file_list)}"
+            )
+
+        (matching_file,) = matching_files
+
+        subfolder = str(Path(matching_file).parent)
+        filename = Path(matching_file).name
+
+        local_dir = local_dir or "."
+
+        # download the file
+        hf_hub_download(
+            repo_id=repo_id,
+            local_dir=local_dir,
+            filename=filename,
+            subfolder=subfolder,
+            local_dir_use_symlinks=local_dir_use_symlinks,
+        )
+
+        model_path = os.path.join(local_dir, filename)
+
+        return cls(
+            model_path=model_path,
+            **kwargs,
+        )
 
 
 class LlamaState:
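# Usage sketch for the new classmethod (the repo id and glob below are example
# values): the glob is matched against the repo's file listing, the single
# matching GGUF is downloaded via huggingface-hub, and the resulting local
# path is handed to the regular Llama constructor.
from llama_cpp import Llama

llm = Llama.from_pretrained(
    repo_id="TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
    filename="*Q4_K_M.gguf",  # must match exactly one file in the repo
    verbose=False,
)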