Skip to content

Commit 61106a2

Browse files
fix: formatting
1 parent 3d293b4 commit 61106a2

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

src/agentlab/llm/llm_utils.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import numpy as np
1515
import tiktoken
1616
import yaml
17-
# from langchain.schema import BaseMessage
1817
from langchain.schema import BaseMessage as LangchainBaseMessage
1918
from langchain_community.adapters.openai import convert_message_to_dict
2019
from PIL import Image
@@ -186,11 +185,15 @@ def get_tokenizer(model_name="gpt-4"):
186185
try:
187186
return tiktoken.encoding_for_model(model_name)
188187
except KeyError:
189-
logging.info(f"Could not find a tokenizer for model {model_name}. Trying HuggingFace.")
188+
logging.info(
189+
f"Could not find a tokenizer for model {model_name}. Trying HuggingFace."
190+
)
190191
try:
191192
return AutoTokenizer.from_pretrained(model_name)
192193
except OSError:
193-
logging.info(f"Could not find a tokenizer for model {model_name}. Defaulting to gpt-4.")
194+
logging.info(
195+
f"Could not find a tokenizer for model {model_name}. Defaulting to gpt-4."
196+
)
194197
return tiktoken.encoding_for_model("gpt-4")
195198

196199

@@ -402,7 +405,9 @@ def __str__(self, warn_if_image=False) -> str:
402405
else:
403406
logging.info(msg)
404407

405-
return "\n".join([elem["text"] for elem in self["content"] if elem["type"] == "text"])
408+
return "\n".join(
409+
[elem["text"] for elem in self["content"] if elem["type"] == "text"]
410+
)
406411

407412
def add_content(self, type: str, content: Any):
408413
if isinstance(self["content"], str):
@@ -540,11 +545,12 @@ def __getitem__(self, key):
540545

541546
def to_markdown(self):
542547
self.merge()
543-
return "\n".join([f"Message {i}\n{m.to_markdown()}\n" for i, m in enumerate(self.messages)])
548+
return "\n".join(
549+
[f"Message {i}\n{m.to_markdown()}\n" for i, m in enumerate(self.messages)]
550+
)
544551

545552

546553
if __name__ == "__main__":
547-
548554
# model_to_download = "THUDM/agentlm-70b"
549555
model_to_download = "databricks/dbrx-instruct"
550556
save_dir = "/mnt/ui_copilot/data_rw/base_models/"

0 commit comments

Comments
 (0)