Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/blazingmq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from . import exceptions
from . import session_events
from ._about import __version__
from ._authncb import BasicAuthnCredentialCb
from ._enums import AckStatus
from ._enums import CompressionAlgorithmType
from ._enums import PropertyType
Expand All @@ -34,6 +35,7 @@
__all__ = [
"Ack",
"AckStatus",
"BasicAuthnCredentialCb",
"BasicHealthMonitor",
"CompressionAlgorithmType",
"Error",
Expand Down
32 changes: 32 additions & 0 deletions src/blazingmq/_authncb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright 2019-2023 Bloomberg Finance L.P.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations
from typing import Callable, Optional, Tuple
from ._ext import FakeAuthnCredentialCb

CredentialTuple = Tuple[str, bytes]


class BasicAuthnCredentialCb:
"""Wrap a Python callable returning (mechanism:str, data:bytes) or None."""

def __init__(self, callback: Callable[[], Optional[CredentialTuple]]):
if not callable(callback):
raise TypeError("callback must be callable")
self._authncb = FakeAuthnCredentialCb(callback)

def __repr__(self) -> str:
return "BasicAuthnCredentialCb(...)"
4 changes: 4 additions & 0 deletions src/blazingmq/_ext.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ class FakeHostHealthMonitor:
def set_healthy(self) -> None: ...
def set_unhealthy(self) -> None: ...

class FakeAuthnCredentialCb:
def __init__(self, callback: Callable[[], Optional[tuple[str, bytes]]]) -> None: ...

class Session:
def __init__(
self,
Expand All @@ -53,6 +56,7 @@ class Session:
timeouts: Timeouts = Timeouts(),
monitor_host_health: bool = False,
fake_host_health_monitor: Optional[FakeHostHealthMonitor] = None,
fake_authn_credential_cb: Optional[FakeAuthnCredentialCb] = None,
) -> None: ...
def stop(self) -> None: ...
def open_queue_sync(
Expand Down
37 changes: 37 additions & 0 deletions src/blazingmq/_ext.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@ import weakref
from bsl cimport optional
from bsl cimport pair
from bsl cimport shared_ptr
from bsl cimport vector
from bsl cimport string
from bsl.bsls cimport TimeInterval
from cpython.ceval cimport PyEval_InitThreads
from libcpp cimport bool as cppbool

from bmq.bmqa cimport ManualHostHealthMonitor
from bmq.bmqt cimport AckResult
from bmq.bmqt cimport AuthnCredential
from bmq.bmqt cimport CompressionAlgorithmType
from bmq.bmqt cimport HostHealthState
from bmq.bmqt cimport PropertyType
Expand Down Expand Up @@ -153,6 +156,38 @@ cdef class FakeHostHealthMonitor:
self._monitor.get().setState(HostHealthState.e_UNHEALTHY)


cdef class FakeAuthnCredentialCb:
cdef object _callback # Store the Python callable

def __cinit__(self, callback):
if not callable(callback):
raise TypeError("callback must be callable")
self._callback = callback

# This method will be called by C++ code via PyObject_CallMethod
# Returns None for no credential, or (mechanism, data) tuple
def get_credential_data(self):
try:
result = self._callback()
if result is None:
return None

if not isinstance(result, tuple) or len(result) != 2:
raise ValueError("callback must return (str, bytes) or None")

mechanism, data = result
if not isinstance(mechanism, str) or not isinstance(data, bytes):
raise ValueError("callback must return (str, bytes) or None")

# Return as-is, let C++ side handle conversion
return result

except Exception as e:
# Log error or handle as needed
LOGGER.exception("Error in authentication credential callback")
return None


cdef class Session:
cdef object __weakref__
cdef NativeSession* _session
Expand All @@ -173,6 +208,7 @@ cdef class Session:
timeouts: _timeouts.Timeouts = _timeouts.Timeouts(),
monitor_host_health: bool = False,
fake_host_health_monitor: FakeHostHealthMonitor = None,
fake_authn_credential_cb: FakeAuthnCredentialCb = None,
_mock: Optional[object] = None,
) -> None:
cdef shared_ptr[ManualHostHealthMonitor] fake_host_health_monitor_sp
Expand Down Expand Up @@ -224,6 +260,7 @@ cdef class Session:
session_cb,
message_cb,
ack_cb,
fake_authn_credential_cb,
c_broker_uri,
c_script_name,
COMPRESSION_ALGO_FROM_PY_MAPPING[message_compression_algorithm],
Expand Down
4 changes: 4 additions & 0 deletions src/blazingmq/_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from ._messages import Message
from ._messages import MessageHandle
from ._monitors import BasicHealthMonitor
from ._authncb import BasicAuthnCredentialCb
from ._timeouts import Timeouts
from ._typing import PropertyTypeDict
from ._typing import PropertyValueDict
Expand Down Expand Up @@ -418,6 +419,7 @@ def __init__(
),
timeout: Union[Timeouts, float] = DEFAULT_TIMEOUT,
host_health_monitor: Union[BasicHealthMonitor, None] = (DefaultMonitor()),
authn_credential_cb: Optional[BasicAuthnCredentialCb] = None,
num_processing_threads: Optional[int] = None,
blob_buffer_size: Optional[int] = None,
channel_high_watermark: Optional[int] = None,
Expand All @@ -433,6 +435,7 @@ def __init__(

monitor_host_health = host_health_monitor is not None
fake_host_health_monitor = getattr(host_health_monitor, "_monitor", None)
fake_authn_credential_cb = getattr(authn_credential_cb, "_authncb", None)

self._has_no_on_message = on_message is None

Expand All @@ -459,6 +462,7 @@ def __init__(
timeouts=_validate_timeouts(timeout),
monitor_host_health=monitor_host_health,
fake_host_health_monitor=fake_host_health_monitor,
fake_authn_credential_cb=fake_authn_credential_cb,
)

@classmethod
Expand Down
81 changes: 78 additions & 3 deletions src/cpp/pybmq_session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include <pybmq_session.h>

#include <pybmq_gilacquireguard.h>
#include <pybmq_gilreleaseguard.h>
#include <pybmq_messageutils.h>
#include <pybmq_mocksession.h>
Expand Down Expand Up @@ -77,13 +78,14 @@ Session::Session(
PyObject* py_session_event_callback,
PyObject* py_message_event_callback,
PyObject* py_ack_event_callback,
PyObject* fake_authn_credential_cb,
const char* broker_uri,
const char* script_name,
bmqt::CompressionAlgorithmType::Enum message_compression_type,
bsl::optional<int> num_processing_threads,
bsl::optional<int> blob_buffer_size,
bsl::optional<int> channel_high_watermark,
bsl::optional<bsl::pair<int, int> > event_queue_watermarks,
bsl::optional<bsl::pair<int, int>> event_queue_watermarks,
const bsls::TimeInterval& stats_dump_interval,
const bsls::TimeInterval& connect_timeout,
const bsls::TimeInterval& disconnect_timeout,
Expand Down Expand Up @@ -119,6 +121,74 @@ Session::Session(
}

d_message_compression_type = message_compression_type;

AuthnCredentialCb cpp_callback;
bool has_auth_callback = false;

if (fake_authn_credential_cb != nullptr && fake_authn_credential_cb != Py_None) {
// Increment reference count since we're storing the Python object
Py_INCREF(fake_authn_credential_cb);
has_auth_callback = true;

// Create a C++ lambda that wraps the Python callback
cpp_callback =
[fake_authn_credential_cb](
bsl::ostream& error) -> bsl::optional<bmqt::AuthnCredential> {
pybmq::GilAcquireGuard guard;

// Call get_credential_data() method on the Python object
bslma::ManagedPtr<PyObject> result =
RefUtils::toManagedPtr(PyObject_CallMethod(
fake_authn_credential_cb,
"get_credential_data",
nullptr));

if (!result) {
// Python exception occurred
PyErr_Print();
error << "Error calling get_credential_data()";
return bsl::optional<bmqt::AuthnCredential>();
}

if (result.get() == Py_None) {
return bsl::optional<bmqt::AuthnCredential>();
}

// Extract tuple (mechanism, data)
if (!PyTuple_Check(result.get()) || PyTuple_Size(result.get()) != 2) {
error << "get_credential_data() must return (str, bytes) or None";
return bsl::optional<bmqt::AuthnCredential>();
}

PyObject* mechanism_obj = PyTuple_GetItem(result.get(), 0);
PyObject* data_obj = PyTuple_GetItem(result.get(), 1);

if (!PyUnicode_Check(mechanism_obj) || !PyBytes_Check(data_obj)) {
error << "get_credential_data() must return (str, bytes) or None";
return bsl::optional<bmqt::AuthnCredential>();
}

// Convert Python str to C++ string
const char* mechanism_cstr = PyUnicode_AsUTF8(mechanism_obj);
bsl::string mechanism(mechanism_cstr);

// Convert Python bytes to vector<char>
char* data_ptr;
Py_ssize_t data_len;
PyBytes_AsStringAndSize(data_obj, &data_ptr, &data_len);
bsl::vector<char> data(data_ptr, data_ptr + data_len);

// Construct and return AuthnCredential
bmqt::AuthnCredential credential;
credential.setMechanism(mechanism).setData(data);

// Move credential into optional (AuthnCredential is move-only)
bsl::optional<bmqt::AuthnCredential> opt_credential;
opt_credential.emplace(bslmf::MovableRefUtil::move(credential));
return opt_credential;
};
}

{
pybmq::GilReleaseGuard guard;
bmqt::SessionOptions options;
Expand All @@ -144,6 +214,11 @@ Session::Session(
event_queue_watermarks.value().second);
}

if (has_auth_callback) {
// TODO: This will only compile with setAuthnCredentialCb in SessionOptions
options.setAuthnCredentialCb(cpp_callback);
}

if (stats_dump_interval != bsls::TimeInterval()) {
options.setStatsDumpInterval(stats_dump_interval);
}
Expand Down Expand Up @@ -527,8 +602,8 @@ Session::post(
oss << "Failed to post message to " << queue_uri << " queue: " << post_rc;
throw GenericError(oss.str());
}
// We have a successful post and the SDK now owns the `on_ack` callback object
// so release our reference without a DECREF.
// We have a successful post and the SDK now owns the `on_ack` callback
// object so release our reference without a DECREF.
managed_on_ack.release();
} catch (const GenericError& exc) {
PyErr_SetString(d_error, exc.what());
Expand Down
6 changes: 6 additions & 0 deletions src/cpp/pybmq_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <bmqa_abstractsession.h>
#include <bmqa_manualhosthealthmonitor.h>
#include <bmqt_authncredential.h>
#include <bmqt_compressionalgorithmtype.h>

#include <bsl_memory.h>
Expand All @@ -47,10 +48,15 @@ class Session
Session(const Session&);
Session& operator=(const Session&);

// TODO: Remove this once it's added in SessionOptions
typedef bsl::function<bsl::optional<bmqt::AuthnCredential>(bsl::ostream& error)>
AuthnCredentialCb;

public:
Session(PyObject* py_session_event_callback,
PyObject* py_message_event_callback,
PyObject* py_ack_event_callback,
PyObject* fake_authn_credential_cb,
const char* broker_uri,
const char* script_name,
bmqt::CompressionAlgorithmType::Enum message_compression_type,
Expand Down
10 changes: 10 additions & 0 deletions src/declarations/bmq/bmqt.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# limitations under the License.

from libcpp cimport bool
from bsl cimport string
from bsl cimport vector


cdef extern from "bmqt_sessioneventtype.h" namespace "BloombergLP::bmqt::SessionEventType" nogil:
Expand Down Expand Up @@ -73,3 +75,11 @@ cdef extern from "bmqt_queueoptions.h" namespace "BloombergLP::bmqt::QueueOption
int k_DEFAULT_MAX_UNCONFIRMED_BYTES
int k_DEFAULT_CONSUMER_PRIORITY
bool k_DEFAULT_SUSPENDS_ON_BAD_HOST_HEALTH

cdef extern from "bmqt_authncredential.h" namespace "BloombergLP::bmqt" nogil:
cdef cppclass AuthnCredential:
AuthnCredential() except +
AuthnCredential& setMechanism(const string&) except +
AuthnCredential& setData(const vector[char]&) except +
const string& mechanism() const
const vector[char]& data() const
1 change: 1 addition & 0 deletions src/declarations/pybmq.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ cdef extern from "pybmq_session.h" namespace "BloombergLP::pybmq" nogil:
Session(object on_session_event,
object on_message_event,
object on_ack_event,
object fake_authn_credential_cb,
const char* broker_uri,
const char* script_name,
CompressionAlgorithmType message_compression_algorithm,
Expand Down
Loading