|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import logging |
| 4 | +from typing import Optional, Sequence |
| 5 | + |
| 6 | +from dapr_agents.agents.configs import WorkflowGrpcOptions |
| 7 | + |
| 8 | +logger = logging.getLogger(__name__) |
| 9 | + |
| 10 | + |
| 11 | +# This is a copy of the original get_grpc_channel function in durabletask.internal.shared at |
| 12 | +# https://github.com/dapr/durabletask-python/blob/7070cb07d07978d079f8c099743ee4a66ae70e05/durabletask/internal/shared.py#L30C1-L61C19 |
| 13 | +# but with my option overrides applied above. |
| 14 | +def apply_grpc_options(options: Optional[WorkflowGrpcOptions]) -> None: |
| 15 | + """ |
| 16 | + Patch Durable Task's gRPC channel factory with custom message size limits. |
| 17 | +
|
| 18 | + Durable Task (and therefore Dapr Workflows) creates its gRPC channels via |
| 19 | + ``durabletask.internal.shared.get_grpc_channel``. This helper monkey patches |
| 20 | + that factory so that subsequent runtime/client instances honour the provided |
| 21 | + ``grpc.max_send_message_length`` / ``grpc.max_receive_message_length`` values. |
| 22 | +
|
| 23 | + Users can set either or both options; any non-None value will be applied. |
| 24 | + """ |
| 25 | + if not options: |
| 26 | + return |
| 27 | + # Early return if neither option is set |
| 28 | + if ( |
| 29 | + options.max_send_message_length is None |
| 30 | + and options.max_receive_message_length is None |
| 31 | + ): |
| 32 | + return |
| 33 | + |
| 34 | + try: |
| 35 | + import grpc |
| 36 | + from durabletask.internal import shared |
| 37 | + except ImportError as exc: |
| 38 | + logger.error( |
| 39 | + "Failed to import grpc/durabletask for channel configuration: %s", exc |
| 40 | + ) |
| 41 | + raise |
| 42 | + |
| 43 | + grpc_options = [] |
| 44 | + if options.max_send_message_length: |
| 45 | + grpc_options.append( |
| 46 | + ("grpc.max_send_message_length", options.max_send_message_length) |
| 47 | + ) |
| 48 | + if options.max_receive_message_length: |
| 49 | + grpc_options.append( |
| 50 | + ("grpc.max_receive_message_length", options.max_receive_message_length) |
| 51 | + ) |
| 52 | + |
| 53 | + def get_grpc_channel_with_options( |
| 54 | + host_address: Optional[str], |
| 55 | + secure_channel: bool = False, |
| 56 | + interceptors: Optional[Sequence["grpc.ClientInterceptor"]] = None, |
| 57 | + ): |
| 58 | + if host_address is None: |
| 59 | + host_address = shared.get_default_host_address() |
| 60 | + |
| 61 | + for protocol in getattr(shared, "SECURE_PROTOCOLS", []): |
| 62 | + if host_address.lower().startswith(protocol): |
| 63 | + secure_channel = True |
| 64 | + host_address = host_address[len(protocol) :] |
| 65 | + break |
| 66 | + |
| 67 | + for protocol in getattr(shared, "INSECURE_PROTOCOLS", []): |
| 68 | + if host_address.lower().startswith(protocol): |
| 69 | + secure_channel = False |
| 70 | + host_address = host_address[len(protocol) :] |
| 71 | + break |
| 72 | + |
| 73 | + if secure_channel: |
| 74 | + credentials = grpc.ssl_channel_credentials() |
| 75 | + channel = grpc.secure_channel( |
| 76 | + host_address, credentials, options=grpc_options |
| 77 | + ) |
| 78 | + else: |
| 79 | + channel = grpc.insecure_channel(host_address, options=grpc_options) |
| 80 | + |
| 81 | + if interceptors: |
| 82 | + channel = grpc.intercept_channel(channel, *interceptors) |
| 83 | + |
| 84 | + return channel |
| 85 | + |
| 86 | + shared.get_grpc_channel = get_grpc_channel_with_options |
| 87 | + logger.debug( |
| 88 | + "Applied gRPC options to durabletask channel factory: %s", dict(grpc_options) |
| 89 | + ) |
0 commit comments