Skip to content

Commit f2c876a

Browse files
committed
Add patch_abc utility
1 parent 64de448 commit f2c876a

File tree

4 files changed

+155
-0
lines changed

4 files changed

+155
-0
lines changed
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import argparse
2+
3+
from langchain_core.language_models import BaseLLM, BaseChatModel
4+
from langchain_core.language_models.base import BaseLanguageModel
5+
6+
# important note: if you import these after patching, the patch won't apply!
7+
# mitigation will be added to patch_abc in a future release
8+
from langchain_huggingface import HuggingFaceEndpoint
9+
from langchain_ollama import OllamaLLM
10+
from langchain_openai import ChatOpenAI
11+
12+
from opentelemetry.util._wrap import patch_abc
13+
14+
15+
def parse_args():
16+
parser = argparse.ArgumentParser(description="LangChain model comparison")
17+
parser.add_argument("--provider", choices=["ollama", "openai", "huggingface"], default="ollama",
18+
help="Choose model provider (default: ollama)")
19+
parser.add_argument("--model", type=str, help="Specify model name")
20+
parser.add_argument("--prompt", type=str, default="What is the capital of France?",
21+
help="Input prompt")
22+
23+
return parser.parse_args()
24+
25+
26+
def chat_with_model(model: BaseLanguageModel, prompt: str) -> str:
27+
try:
28+
response = model.invoke(prompt)
29+
if hasattr(response, 'content'):
30+
return response.content
31+
else:
32+
return str(response)
33+
except Exception as e:
34+
return f"Error: {str(e)}"
35+
36+
37+
def create_huggingface_model(model: str = "google/flan-t5-small"):
38+
return HuggingFaceEndpoint(
39+
repo_id=model,
40+
temperature=0.7
41+
)
42+
43+
44+
def create_openai_model(model: str = "gpt-3.5-turbo"):
45+
return ChatOpenAI(
46+
model=model,
47+
temperature=0.7
48+
)
49+
50+
51+
def create_ollama_model(model: str = "llama2"):
52+
return OllamaLLM(
53+
model=model,
54+
temperature=0.7
55+
)
56+
57+
58+
def patch_llm():
59+
def my_wrapper(orig_fcn):
60+
def wrapped_fcn(self, *args, **kwargs):
61+
print("wrapper starting")
62+
print(f"Arguments: {args}")
63+
print(f"Keyword arguments: {kwargs}")
64+
return orig_fcn(self, *args, **kwargs)
65+
66+
return wrapped_fcn
67+
68+
patch_abc(BaseLLM, "_generate", my_wrapper)
69+
70+
# this is for OpenAI, which is a weird case. The _generate method is in a differnt base class and gets called twice.
71+
patch_abc(BaseChatModel, "_generate", my_wrapper)
72+
73+
74+
def main():
75+
args = parse_args()
76+
77+
patch_llm()
78+
79+
if args.provider == "ollama":
80+
model = create_ollama_model(args.model or "llama2")
81+
elif args.provider == "openai":
82+
model = create_openai_model(args.model or "gpt-3.5-turbo")
83+
elif args.provider == "huggingface":
84+
model = create_huggingface_model(args.model or "google/flan-t5-small")
85+
else:
86+
raise ValueError(f"Unsupported provider: {args.provider}")
87+
88+
response = chat_with_model(model, args.prompt)
89+
print(f"{args.provider.title()} Response: {response}")
90+
91+
92+
if __name__ == "__main__":
93+
main()
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
langchain-core
2+
langchain-openai
3+
langchain-ollama
4+
langchain-huggingface
5+
huggingface-hub
6+
opentelemetry-api
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from abc import ABC, abstractmethod
2+
3+
from opentelemetry.util._wrap import patch_abc
4+
5+
6+
class Greeter(ABC):
7+
@abstractmethod
8+
def greet(self):
9+
pass
10+
11+
12+
class EngishGreeter(Greeter):
13+
def greet(self):
14+
print("hello")
15+
16+
17+
class SpanishGreeter(Greeter):
18+
def greet(self):
19+
print("hola")
20+
21+
22+
if __name__ == '__main__':
23+
def my_wrapper(orig_fcn):
24+
def wrapped_fcn(self, *args, **kwargs):
25+
print("wrapper running")
26+
result = orig_fcn(self, *args, **kwargs)
27+
return result
28+
29+
return wrapped_fcn
30+
31+
32+
patch_abc(Greeter, "greet", my_wrapper)
33+
34+
EngishGreeter().greet()
35+
SpanishGreeter().greet()
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
def patch_abc(abstract_base_class, method_name, w):
2+
"""
3+
Patches a method across all subclasses of an abstract base class.
4+
"""
5+
subclasses = recursively_get_all_subclasses(abstract_base_class)
6+
for subclass in subclasses:
7+
old_method = getattr(subclass, method_name)
8+
setattr(subclass, method_name, w(old_method))
9+
10+
# This implementation does not work if the instrumented class is imported after the instrumentor runs.
11+
# However, that case can be handled by querying the gc module for all existing classes; this capability can be added
12+
# in a follow-up release.
13+
14+
15+
def recursively_get_all_subclasses(cls):
16+
out = set()
17+
for subclass in cls.__subclasses__():
18+
out.add(subclass)
19+
out.update(recursively_get_all_subclasses(subclass))
20+
return out
21+

0 commit comments

Comments
 (0)