|
25 | 25 | from ddtrace.contrib.trace_utils import wrap |
26 | 26 | from ddtrace.internal.agent import get_stats_url |
27 | 27 | from ddtrace.internal.logger import get_logger |
| 28 | +from ddtrace.internal.utils import ArgumentError |
28 | 29 | from ddtrace.internal.utils import get_argument_value |
29 | 30 | from ddtrace.internal.utils.formats import asbool |
30 | 31 | from ddtrace.internal.utils.formats import deep_getattr |
@@ -326,7 +327,7 @@ async def traced_llm_agenerate(langchain, pin, func, instance, args, kwargs): |
326 | 327 | @with_traced_module |
327 | 328 | def traced_chat_model_generate(langchain, pin, func, instance, args, kwargs): |
328 | 329 | llm_provider = instance._llm_type.split("-")[0] |
329 | | - chat_messages = get_argument_value(args, kwargs, 0, "chat_messages") |
| 330 | + chat_messages = get_argument_value(args, kwargs, 0, "messages") |
330 | 331 | integration = langchain._datadog_integration |
331 | 332 | span = integration.trace( |
332 | 333 | pin, |
@@ -417,7 +418,7 @@ def traced_chat_model_generate(langchain, pin, func, instance, args, kwargs): |
417 | 418 | @with_traced_module |
418 | 419 | async def traced_chat_model_agenerate(langchain, pin, func, instance, args, kwargs): |
419 | 420 | llm_provider = instance._llm_type.split("-")[0] |
420 | | - chat_messages = get_argument_value(args, kwargs, 0, "chat_messages") |
| 421 | + chat_messages = get_argument_value(args, kwargs, 0, "messages") |
421 | 422 | integration = langchain._datadog_integration |
422 | 423 | span = integration.trace( |
423 | 424 | pin, |
@@ -507,7 +508,15 @@ async def traced_chat_model_agenerate(langchain, pin, func, instance, args, kwar |
507 | 508 |
|
508 | 509 | @with_traced_module |
509 | 510 | def traced_embedding(langchain, pin, func, instance, args, kwargs): |
510 | | - input_texts = get_argument_value(args, kwargs, 0, "text") |
| 511 | + """ |
| 512 | + This traces both embed_query(text) and embed_documents(texts), so we need to make sure |
| 513 | + we get the right arg/kwarg. |
| 514 | + """ |
| 515 | + try: |
| 516 | + input_texts = get_argument_value(args, kwargs, 0, "texts") |
| 517 | + except ArgumentError: |
| 518 | + input_texts = get_argument_value(args, kwargs, 0, "text") |
| 519 | + |
511 | 520 | provider = instance.__class__.__name__.split("Embeddings")[0].lower() |
512 | 521 | integration = langchain._datadog_integration |
513 | 522 | span = integration.trace( |
@@ -559,7 +568,7 @@ def traced_chain_call(langchain, pin, func, instance, args, kwargs): |
559 | 568 | span = integration.trace(pin, "%s.%s" % (instance.__module__, instance.__class__.__name__), interface_type="chain") |
560 | 569 | final_outputs = {} |
561 | 570 | try: |
562 | | - inputs = args[0] |
| 571 | + inputs = get_argument_value(args, kwargs, 0, "inputs") |
563 | 572 | if not isinstance(inputs, dict): |
564 | 573 | inputs = {instance.input_keys[0]: inputs} |
565 | 574 | if integration.is_pc_sampled_span(span): |
@@ -605,7 +614,7 @@ async def traced_chain_acall(langchain, pin, func, instance, args, kwargs): |
605 | 614 | span = integration.trace(pin, "%s.%s" % (instance.__module__, instance.__class__.__name__), interface_type="chain") |
606 | 615 | final_outputs = {} |
607 | 616 | try: |
608 | | - inputs = args[0] |
| 617 | + inputs = get_argument_value(args, kwargs, 0, "inputs") |
609 | 618 | if not isinstance(inputs, dict): |
610 | 619 | inputs = {instance.input_keys[0]: inputs} |
611 | 620 | if integration.is_pc_sampled_span(span): |
|
0 commit comments