|
14 | 14 | import numpy as np |
15 | 15 | import tiktoken |
16 | 16 | import yaml |
17 | | -# from langchain.schema import BaseMessage |
18 | 17 | from langchain.schema import BaseMessage as LangchainBaseMessage |
19 | 18 | from langchain_community.adapters.openai import convert_message_to_dict |
20 | 19 | from PIL import Image |
@@ -186,11 +185,15 @@ def get_tokenizer(model_name="gpt-4"): |
186 | 185 | try: |
187 | 186 | return tiktoken.encoding_for_model(model_name) |
188 | 187 | except KeyError: |
189 | | - logging.info(f"Could not find a tokenizer for model {model_name}. Trying HuggingFace.") |
| 188 | + logging.info( |
| 189 | + f"Could not find a tokenizer for model {model_name}. Trying HuggingFace." |
| 190 | + ) |
190 | 191 | try: |
191 | 192 | return AutoTokenizer.from_pretrained(model_name) |
192 | 193 | except OSError: |
193 | | - logging.info(f"Could not find a tokenizer for model {model_name}. Defaulting to gpt-4.") |
| 194 | + logging.info( |
| 195 | + f"Could not find a tokenizer for model {model_name}. Defaulting to gpt-4." |
| 196 | + ) |
194 | 197 | return tiktoken.encoding_for_model("gpt-4") |
195 | 198 |
|
196 | 199 |
|
@@ -402,7 +405,9 @@ def __str__(self, warn_if_image=False) -> str: |
402 | 405 | else: |
403 | 406 | logging.info(msg) |
404 | 407 |
|
405 | | - return "\n".join([elem["text"] for elem in self["content"] if elem["type"] == "text"]) |
| 408 | + return "\n".join( |
| 409 | + [elem["text"] for elem in self["content"] if elem["type"] == "text"] |
| 410 | + ) |
406 | 411 |
|
407 | 412 | def add_content(self, type: str, content: Any): |
408 | 413 | if isinstance(self["content"], str): |
@@ -540,11 +545,12 @@ def __getitem__(self, key): |
540 | 545 |
|
541 | 546 | def to_markdown(self): |
542 | 547 | self.merge() |
543 | | - return "\n".join([f"Message {i}\n{m.to_markdown()}\n" for i, m in enumerate(self.messages)]) |
| 548 | + return "\n".join( |
| 549 | + [f"Message {i}\n{m.to_markdown()}\n" for i, m in enumerate(self.messages)] |
| 550 | + ) |
544 | 551 |
|
545 | 552 |
|
546 | 553 | if __name__ == "__main__": |
547 | | - |
548 | 554 | # model_to_download = "THUDM/agentlm-70b" |
549 | 555 | model_to_download = "databricks/dbrx-instruct" |
550 | 556 | save_dir = "/mnt/ui_copilot/data_rw/base_models/" |
|
0 commit comments