|
1 | 1 | """ |
2 | 2 | Copyright (c) 2024 Scale3 Labs |
3 | | -
|
4 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); |
5 | 4 | you may not use this file except in compliance with the License. |
6 | 5 | You may obtain a copy of the License at |
7 | | -
|
8 | 6 | http://www.apache.org/licenses/LICENSE-2.0 |
9 | | -
|
10 | 7 | Unless required by applicable law or agreed to in writing, software |
11 | 8 | distributed under the License is distributed on an "AS IS" BASIS, |
12 | 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | 10 | See the License for the specific language governing permissions and |
14 | 11 | limitations under the License. |
15 | 12 | """ |
16 | 13 |
|
| 14 | +from typing import Collection, Optional, Any |
17 | 15 | import importlib.metadata |
18 | 16 | import logging |
19 | | -from typing import Collection |
20 | 17 |
|
21 | 18 | from opentelemetry.instrumentation.instrumentor import BaseInstrumentor |
22 | | -from opentelemetry.trace import get_tracer |
| 19 | +from opentelemetry.trace import get_tracer, TracerProvider |
23 | 20 | from wrapt import wrap_function_wrapper |
24 | 21 |
|
25 | 22 | from langtrace_python_sdk.instrumentation.openai.patch import ( |
|
35 | 32 | logging.basicConfig(level=logging.FATAL) |
36 | 33 |
|
37 | 34 |
|
38 | | -class OpenAIInstrumentation(BaseInstrumentor): |
| 35 | +class OpenAIInstrumentation(BaseInstrumentor): # type: ignore |
39 | 36 |
|
40 | 37 | def instrumentation_dependencies(self) -> Collection[str]: |
41 | 38 | return ["openai >= 0.27.0", "trace-attributes >= 4.0.5"] |
42 | 39 |
|
43 | | - def _instrument(self, **kwargs): |
44 | | - tracer_provider = kwargs.get("tracer_provider") |
| 40 | + def _instrument(self, **kwargs: Any) -> None: |
| 41 | + tracer_provider: Optional[TracerProvider] = kwargs.get("tracer_provider") |
45 | 42 | tracer = get_tracer(__name__, "", tracer_provider) |
46 | | - version = importlib.metadata.version("openai") |
| 43 | + version: str = importlib.metadata.version("openai") |
47 | 44 |
|
48 | 45 | wrap_function_wrapper( |
49 | 46 | "openai.resources.chat.completions", |
50 | 47 | "Completions.create", |
51 | | - chat_completions_create("openai.chat.completions.create", version, tracer), |
| 48 | + chat_completions_create(version, tracer), |
52 | 49 | ) |
53 | 50 |
|
54 | 51 | wrap_function_wrapper( |
55 | 52 | "openai.resources.chat.completions", |
56 | 53 | "AsyncCompletions.create", |
57 | | - async_chat_completions_create( |
58 | | - "openai.chat.completions.create_stream", version, tracer |
59 | | - ), |
| 54 | + async_chat_completions_create(version, tracer), |
60 | 55 | ) |
61 | 56 |
|
62 | 57 | wrap_function_wrapper( |
63 | 58 | "openai.resources.images", |
64 | 59 | "Images.generate", |
65 | | - images_generate("openai.images.generate", version, tracer), |
| 60 | + images_generate(version, tracer), |
66 | 61 | ) |
67 | 62 |
|
68 | 63 | wrap_function_wrapper( |
69 | 64 | "openai.resources.images", |
70 | 65 | "AsyncImages.generate", |
71 | | - async_images_generate("openai.images.generate", version, tracer), |
| 66 | + async_images_generate(version, tracer), |
72 | 67 | ) |
73 | 68 |
|
74 | 69 | wrap_function_wrapper( |
75 | 70 | "openai.resources.images", |
76 | 71 | "Images.edit", |
77 | | - images_edit("openai.images.edit", version, tracer), |
| 72 | + images_edit(version, tracer), |
78 | 73 | ) |
79 | 74 |
|
80 | 75 | wrap_function_wrapper( |
81 | 76 | "openai.resources.embeddings", |
82 | 77 | "Embeddings.create", |
83 | | - embeddings_create("openai.embeddings.create", version, tracer), |
| 78 | + embeddings_create(version, tracer), |
84 | 79 | ) |
85 | 80 |
|
86 | 81 | wrap_function_wrapper( |
87 | 82 | "openai.resources.embeddings", |
88 | 83 | "AsyncEmbeddings.create", |
89 | | - async_embeddings_create("openai.embeddings.create", version, tracer), |
| 84 | + async_embeddings_create(version, tracer), |
90 | 85 | ) |
91 | 86 |
|
92 | | - def _uninstrument(self, **kwargs): |
| 87 | + def _uninstrument(self, **kwargs: Any) -> None: |
93 | 88 | pass |
0 commit comments