Skip to content

Commit f297ea6

Browse files
committed
Add patch_abc utility
1 parent 64de448 commit f297ea6

File tree

4 files changed

+178
-0
lines changed

4 files changed

+178
-0
lines changed
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import argparse
2+
3+
from langchain_core.language_models import BaseLLM, BaseChatModel
4+
from langchain_core.language_models.base import BaseLanguageModel
5+
6+
# important note: if you import these after patching, the patch won't apply!
7+
# mitigation will be added to patch_abc in a future release
8+
from langchain_huggingface import HuggingFaceEndpoint
9+
from langchain_ollama import OllamaLLM
10+
from langchain_openai import ChatOpenAI
11+
12+
from opentelemetry.util._wrap import patch_abc
13+
14+
15+
def parse_args(argv=None):
    """Parse command-line options for the model-comparison demo.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``,
            which makes argparse read ``sys.argv[1:]`` exactly as before;
            passing an explicit list makes the parser unit-testable.

    Returns:
        argparse.Namespace with ``provider``, ``model`` and ``prompt``.
    """
    parser = argparse.ArgumentParser(description="LangChain model comparison")
    parser.add_argument(
        "--provider",
        choices=["ollama", "openai", "huggingface"],
        default="ollama",
        help="Choose model provider (default: ollama)",
    )
    parser.add_argument("--model", type=str, help="Specify model name")
    parser.add_argument(
        "--prompt",
        type=str,
        default="What is the capital of France?",
        help="Input prompt",
    )

    return parser.parse_args(argv)
24+
25+
26+
def chat_with_model(model: BaseLanguageModel, prompt: str) -> str:
    """Send *prompt* to *model* and return the reply as plain text.

    Chat models return message objects carrying a ``content`` attribute,
    while plain LLMs return strings; both cases are normalized to ``str``.
    Any exception raised during invocation is reported as an
    ``"Error: ..."`` string rather than propagated, so the demo never
    crashes on a provider failure.
    """
    try:
        reply = model.invoke(prompt)
        return reply.content if hasattr(reply, "content") else str(reply)
    except Exception as exc:
        return f"Error: {str(exc)}"
35+
36+
37+
def create_huggingface_model(model: str = "google/flan-t5-small"):
    """Build a HuggingFace Inference-Endpoint LLM for *model*.

    The temperature matches the other providers so the comparison is fair.
    """
    return HuggingFaceEndpoint(repo_id=model, temperature=0.7)
42+
43+
44+
def create_openai_model(model: str = "gpt-3.5-turbo"):
    """Build an OpenAI chat model for *model*.

    The temperature matches the other providers so the comparison is fair.
    """
    return ChatOpenAI(model=model, temperature=0.7)
49+
50+
51+
def create_ollama_model(model: str = "llama2"):
    """Build a local Ollama LLM for *model*.

    The temperature matches the other providers so the comparison is fair.
    """
    return OllamaLLM(model=model, temperature=0.7)
56+
57+
58+
def patch_llm():
    """Install a logging wrapper around every concrete LLM implementation.

    Uses ``patch_abc`` to replace ``_generate`` on each leaf subclass with a
    version that prints the call arguments before delegating to the real
    method unchanged.
    """
    def logging_wrapper(orig_fcn):
        def wrapped_fcn(self, *args, **kwargs):
            print("wrapper starting")
            print(f"Arguments: {args}")
            print(f"Keyword arguments: {kwargs}")
            return orig_fcn(self, *args, **kwargs)

        return wrapped_fcn

    patch_abc(BaseLLM, "_generate", logging_wrapper)

    # OpenAI is a special case: its _generate lives in a different base class
    # (BaseChatModel) and gets called twice.
    patch_abc(BaseChatModel, "_generate", logging_wrapper)
72+
73+
74+
def main():
    """Entry point: parse CLI args, patch the LLM classes, run one prompt."""
    args = parse_args()

    patch_llm()

    # Provider name -> (factory, default model); the CLI restricts
    # --provider to these keys, so the ValueError below is a safety net.
    factories = {
        "ollama": (create_ollama_model, "llama2"),
        "openai": (create_openai_model, "gpt-3.5-turbo"),
        "huggingface": (create_huggingface_model, "google/flan-t5-small"),
    }
    if args.provider not in factories:
        raise ValueError(f"Unsupported provider: {args.provider}")

    factory, default_model = factories[args.provider]
    model = factory(args.model or default_model)

    response = chat_with_model(model, args.prompt)
    print(f"{args.provider.title()} Response: {response}")
90+
91+
92+
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
langchain-core
2+
langchain-openai
3+
langchain-ollama
4+
langchain-huggingface
5+
huggingface-hub
6+
opentelemetry-api
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from abc import ABC, abstractmethod
2+
3+
from opentelemetry.util._wrap import patch_abc
4+
5+
6+
class Greeter(ABC):
    """Abstract interface for objects that print a greeting."""

    @abstractmethod
    def greet(self):
        """Print a greeting to stdout."""
10+
11+
12+
class EngishGreeter(Greeter):
    """Greeter that says hello in English.

    NOTE(review): the class name is missing an "l" (English); kept as-is
    because the demo below refers to it by this spelling.
    """

    def greet(self):
        print("hello")
15+
16+
17+
class SpanishGreeter(Greeter):
    """Greeter that says hello in Spanish."""

    def greet(self):
        print("hola")
20+
21+
22+
if __name__ == '__main__':
    # Decorator-style wrapper factory: announces itself, then delegates to
    # the original method unchanged.
    def my_wrapper(orig_fcn):
        def wrapped_fcn(self, *args, **kwargs):
            print("wrapper running")
            return orig_fcn(self, *args, **kwargs)

        return wrapped_fcn

    # Wrap greet() on every concrete Greeter subclass, then show that both
    # implementations now go through the wrapper.
    patch_abc(Greeter, "greet", my_wrapper)

    EngishGreeter().greet()
    SpanishGreeter().greet()
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
def patch_abc(abstract_base_class, method_name, w):
    """
    Patches a method on leaf subclasses of an abstract base class.

    Every leaf class below *abstract_base_class* (a class with no further
    subclasses currently defined) that exposes a callable *method_name* has
    that method replaced by ``w(original)``, where *w* is a decorator-style
    wrapper factory.
    """
    descendants = recursively_get_all_subclasses(abstract_base_class)
    for leaf in get_leaf_subclasses(descendants):
        # Patch only when the method exists (defined or inherited) and is
        # actually callable; other leaves are skipped silently.
        target = getattr(leaf, method_name, None)
        if callable(target):
            setattr(leaf, method_name, w(target))
15+
16+
# This implementation does not work if the instrumented class is imported after the instrumentor runs.
17+
# However, that case can be handled by querying the gc module for all existing classes; this capability can be added
18+
# in a follow-up change.
19+
20+
21+
def get_leaf_subclasses(all_subclasses):
    """
    Returns only the leaf classes (classes with no subclasses) from a set of classes.

    A class counts as a leaf when no *other* member of *all_subclasses*
    derives from it; subclasses outside the given set are not considered.
    """
    return {
        candidate
        for candidate in all_subclasses
        if not any(
            other is not candidate and issubclass(other, candidate)
            for other in all_subclasses
        )
    }
36+
37+
38+
39+
def recursively_get_all_subclasses(cls):
    """Return every direct and indirect subclass of *cls* as a set."""
    found = set()
    # Iterative worklist traversal of the subclass tree; the membership
    # check guards against revisiting a class reachable via two parents.
    pending = list(cls.__subclasses__())
    while pending:
        current = pending.pop()
        if current not in found:
            found.add(current)
            pending.extend(current.__subclasses__())
    return found

0 commit comments

Comments
 (0)