Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/aiperf/common/enums/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,8 @@ class MessageType(CaseInsensitiveStrEnum):
CREDIT_PHASE_START = "credit_phase_start"
CREDIT_PHASES_CONFIGURED = "credit_phases_configured"
CREDITS_COMPLETE = "credits_complete"
DATASET_CLIENT_REQUEST = "dataset_client_request"
DATASET_CLIENT_RESPONSE = "dataset_client_response"
DATASET_CONFIGURED_NOTIFICATION = "dataset_configured_notification"
ERROR = "error"
HEARTBEAT = "heartbeat"
Expand Down
4 changes: 4 additions & 0 deletions src/aiperf/common/messages/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
ConversationResponseMessage,
ConversationTurnRequestMessage,
ConversationTurnResponseMessage,
DatasetClientRequestMessage,
DatasetClientResponseMessage,
DatasetConfiguredNotification,
)
from aiperf.common.messages.inference_messages import (
Expand Down Expand Up @@ -87,6 +89,8 @@
"ConversationResponseMessage",
"ConversationTurnRequestMessage",
"ConversationTurnResponseMessage",
"DatasetClientRequestMessage",
"DatasetClientResponseMessage",
"DatasetConfiguredNotification",
"ErrorMessage",
"HeartbeatMessage",
Expand Down
25 changes: 25 additions & 0 deletions src/aiperf/common/messages/dataset_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,31 @@
from aiperf.common.types import MessageTypeT


class DatasetClientRequestMessage(BaseServiceMessage):
    """Request the dataset client metadata from the DatasetManager.

    Carries no payload beyond the base service-message fields; the message
    type alone identifies the request, and the DatasetManager replies with a
    ``DatasetClientResponseMessage``.
    """

    # Discriminator used by the message router to dispatch this request.
    message_type: MessageTypeT = MessageType.DATASET_CLIENT_REQUEST


class DatasetClientResponseMessage(BaseServiceMessage):
    """Response containing the dataset client metadata for connection.

    Sent by the DatasetManager in reply to a ``DatasetClientRequestMessage``;
    the requester uses ``client_metadata`` to construct and initialize its
    own dataset client.
    """

    # Discriminator used by the message router to dispatch this response.
    message_type: MessageTypeT = MessageType.DATASET_CLIENT_RESPONSE

    # SerializeAsAny keeps subclass-specific fields when a DatasetClientMetadata
    # subclass instance is serialized through this base-typed field.
    client_metadata: SerializeAsAny[DatasetClientMetadata] = Field(
        ...,
        description="Client access metadata (e.g., mmap file paths) for reading the dataset.",
    )

    @field_validator("client_metadata", mode="before")
    @classmethod
    def route_client_metadata(cls, v: Any) -> DatasetClientMetadata:
        """Route nested AutoRoutedModel field to correct subclass."""
        # Raw dicts (e.g. deserialized JSON off the wire) are re-hydrated via
        # DatasetClientMetadata.from_json so the concrete subclass is selected;
        # already-constructed model instances pass through unchanged.
        if isinstance(v, dict):
            return DatasetClientMetadata.from_json(v)
        return v


class ConversationRequestMessage(BaseServiceMessage):
"""Message to request a full conversation by ID."""

Expand Down
18 changes: 18 additions & 0 deletions src/aiperf/dataset/dataset_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
ConversationResponseMessage,
ConversationTurnRequestMessage,
ConversationTurnResponseMessage,
DatasetClientRequestMessage,
DatasetClientResponseMessage,
DatasetConfiguredNotification,
ProfileConfigureCommand,
)
Expand Down Expand Up @@ -464,6 +466,22 @@ async def _handle_conversation_turn_request(
turn=turn,
)

@on_request(MessageType.DATASET_CLIENT_REQUEST)
async def _handle_dataset_client_request(
    self, message: DatasetClientRequestMessage
) -> DatasetClientResponseMessage:
    """Answer a dataset-client metadata request from another service.

    Waits until the dataset has been configured, then returns the backing
    store's client metadata so the requester can initialize its own client.
    """
    self.debug(lambda: f"Handling dataset client request from {message.service_id}")

    # The backing store's metadata is only meaningful once configuration
    # has completed, so block here until it has.
    await self._wait_for_dataset_configuration()

    response = DatasetClientResponseMessage(
        service_id=self.service_id,
        request_id=message.request_id,
        client_metadata=self._backing_store.get_client_metadata(),
    )
    return response

async def _wait_for_dataset_configuration(self) -> None:
"""Wait for the dataset to be configured if it is not already."""
if not self.dataset_configured.is_set():
Expand Down
98 changes: 42 additions & 56 deletions src/aiperf/workers/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@
WorkerHealthMessage,
)
from aiperf.common.messages.dataset_messages import (
ConversationRequestMessage,
ConversationResponseMessage,
DatasetClientRequestMessage,
DatasetClientResponseMessage,
)
from aiperf.common.mixins import ProcessHealthMixin
from aiperf.common.models import (
Conversation,
DatasetClientMetadata,
ErrorDetails,
ModelEndpointInfo,
ReasoningResponseData,
Expand Down Expand Up @@ -197,10 +198,9 @@ def __init__(
or self.user_config.loadgen.warmup_prefill_concurrency is not None
)

# Only used as a fallback when dataset client is not initialized
# or was not available when the credit was dropped. Must be created here
# so it can be attached to the worker lifecycle.
self.conversation_request_client: RequestClientProtocol = (
# Used to request dataset client metadata from DatasetManager when
# the dataset client is not yet initialized (race condition at startup).
self._dataset_client_request_client: RequestClientProtocol = (
self.comms.create_request_client(
address=CommAddress.DATASET_MANAGER_PROXY_FRONTEND,
bind=False,
Expand All @@ -221,24 +221,30 @@ async def _send_worker_ready_message(self) -> None:

@on_message(MessageType.DATASET_CONFIGURED_NOTIFICATION)
async def _on_dataset_configured(self, msg: DatasetConfiguredNotification) -> None:
"""Initialize dataset client when configuration is received.
"""Initialize dataset client when configuration is received."""
await self._initialize_dataset_client(msg.client_metadata)
self.debug(
lambda: (
f"Dataset client initialized: type={msg.client_metadata.client_type}"
)
)

async def _initialize_dataset_client(
self, client_metadata: DatasetClientMetadata
) -> None:
"""Create and initialize the dataset client from metadata.

Uses factory pattern to dynamically create the appropriate client.
The factory auto-extracts client_type from client_metadata, leveraging
the discriminated union pattern for type-safe routing. This allows new
storage backends (S3, Redis, etc.) to work without modifying Worker code.
"""
ClientStoreClass = plugins.get_class(
PluginType.DATASET_CLIENT_STORE, msg.client_metadata.client_type
PluginType.DATASET_CLIENT_STORE, client_metadata.client_type
)
self._dataset_client = ClientStoreClass(client_metadata=msg.client_metadata)
self._dataset_client = ClientStoreClass(client_metadata=client_metadata)
await self._dataset_client.initialize()
self._dataset_configured_event.set()
Comment on lines 223 to 247
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Don't initialize dataset clients from DatasetManager metadata in Kubernetes.

Lines 223-225 and Lines 588-604 treat DatasetManager client_metadata as worker-local in every run mode. But src/aiperf/dataset/dataset_manager.py Lines 356-364 explicitly say that, in Kubernetes, those paths are control-plane paths that workers should ignore until WorkerPodManager provides local download paths. As written, a pod can try to mmap files that do not exist locally. Guard both paths on run type and keep waiting for the downloaded-path source in Kubernetes.

Also applies to: 588-605

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/aiperf/workers/worker.py` around lines 223 - 247, The worker currently
initializes dataset clients directly from DatasetManager metadata in
_on_dataset_configured and _initialize_dataset_client, which causes pods in
Kubernetes to attempt to use control-plane paths; modify these methods to first
check the run mode (e.g., Kubernetes) and, if running in Kubernetes, do NOT
initialize from DatasetClientMetadata immediately but instead wait for the
WorkerPodManager/local-download path to become available (the “downloaded-path”
source) before creating ClientStoreClass; in practice, add a guard around the
existing initialization logic in
_on_dataset_configured/_initialize_dataset_client that checks run type and
either returns/waits for an alternative notification from WorkerPodManager (or
polls for the downloaded-path) and only then instantiates
ClientStoreClass(client_metadata=...) and calls initialize() and
_dataset_configured_event.set().

Comment on lines +232 to 247
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

wc -l src/aiperf/workers/worker.py

Repository: ai-dynamo/aiperf

Length of output: 92


🏁 Script executed:

# Read the _initialize_dataset_client method and surrounding context
sed -n '230,250p' src/aiperf/workers/worker.py

Repository: ai-dynamo/aiperf

Length of output: 951


🏁 Script executed:

# Read the lines 580-586 where client is used
sed -n '575,590p' src/aiperf/workers/worker.py

Repository: ai-dynamo/aiperf

Length of output: 794


🏁 Script executed:

# Read the lines 588-605 fallback request
sed -n '585,610p' src/aiperf/workers/worker.py

Repository: ai-dynamo/aiperf

Length of output: 1147


🏁 Script executed:

# Search for _on_dataset_configured method
rg -n "_on_dataset_configured" src/aiperf/workers/worker.py -A 10

Repository: ai-dynamo/aiperf

Length of output: 578


🏁 Script executed:

# Search for _dataset_configured_event usage
rg -n "_dataset_configured_event" src/aiperf/workers/worker.py

Repository: ai-dynamo/aiperf

Length of output: 224


🏁 Script executed:

# Check line 673 context to see how the event is used for synchronization
sed -n '670,685p' src/aiperf/workers/worker.py

Repository: ai-dynamo/aiperf

Length of output: 761


🏁 Script executed:

# Check if there's any lock mechanism used with _dataset_client
rg -n "asyncio.Lock\|RLock\|Semaphore\|_dataset_client" src/aiperf/workers/worker.py | head -30

Repository: ai-dynamo/aiperf

Length of output: 42


🏁 Script executed:

# Get full context of the on_pull_message handler to see event wait usage
rg -n "@on_pull_message\|_dataset_configured_event" src/aiperf/workers/worker.py -B 2 -A 5

Repository: ai-dynamo/aiperf

Length of output: 42


🏁 Script executed:

# Search for methods that call get_conversation to understand when it's invoked
rg -n "get_conversation" src/aiperf/workers/worker.py -B 3 -A 1

Repository: ai-dynamo/aiperf

Length of output: 565


🏁 Script executed:

# Check the on_message handlers and their concurrency
rg -n "@on_message\|@on_pull_message\|@on_start\|@on_init" src/aiperf/workers/worker.py | head -20

Repository: ai-dynamo/aiperf

Length of output: 42


🏁 Script executed:

# Get more context around line 580 to see the full get_conversation method
sed -n '570,595p' src/aiperf/workers/worker.py

Repository: ai-dynamo/aiperf

Length of output: 1219


🏁 Script executed:

# Check if there's any initialization guard or lock used elsewhere with _dataset_client
rg -n "self._dataset_client" src/aiperf/workers/worker.py | head -20

Repository: ai-dynamo/aiperf

Length of output: 773


Fix race condition in dataset client initialization.

The assignment of self._dataset_client at line 245 occurs before await self._dataset_client.initialize() completes at line 246. This creates a window where concurrent calls to get_conversation() (line 580) can observe self._dataset_client is not None and invoke methods on an uninitialized client. The synchronization event is set only after initialization (line 247) and is awaited only during startup configuration, not during concurrent get_conversation() calls. Additionally, if get_conversation() triggers the fallback request (line 585), it will invoke _initialize_dataset_client() again while the original initialization is still in progress, causing client conflicts or double-initialization. Build the client in a local variable, await initialization, then assign under proper synchronization (lock or fully completed event).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/aiperf/workers/worker.py` around lines 232 - 247, Create the dataset
client locally, fully initialize it, then publish it to the Worker instance
under synchronization to avoid races: in _initialize_dataset_client build the
client in a local variable (e.g., client = ClientStoreClass(...)), await
client.initialize(), then acquire the existing Worker-level lock (or create one
if missing) and assign self._dataset_client = client and set
self._dataset_configured_event inside that critical section; also update
get_conversation to wait on self._dataset_configured_event (or re-check under
the same lock) before using self._dataset_client and guard against duplicate
concurrent initializations by checking the published self._dataset_client after
waiting and returning early if another initializer already set it.

self.debug(
lambda: (
f"Dataset client initialized: type={msg.client_metadata.client_type}"
)
)

@on_stop
async def _send_worker_shutdown_message(self) -> None:
Expand Down Expand Up @@ -576,52 +582,32 @@ async def _retrieve_conversation(
elif self.stop_requested:
raise asyncio.CancelledError("Stop requested while retrieving conversation")

return await self._request_conversation_from_dataset_manager(
conversation_id, credit_context
)
await self._request_dataset_client_from_dataset_manager()
return await self._dataset_client.get_conversation(conversation_id)

async def _request_conversation_from_dataset_manager(
self, conversation_id: str, credit_context: CreditContext
) -> Conversation:
"""Fallback: Request from DatasetManager via ZMQ"""
conversation_response: (
ConversationResponseMessage | ErrorMessage
) = await self.conversation_request_client.request(
ConversationRequestMessage(
service_id=self.service_id,
conversation_id=conversation_id,
credit_phase=credit_context.credit.phase,
)
async def _request_dataset_client_from_dataset_manager(self) -> None:
"""Fallback: Request dataset client metadata from DatasetManager and initialize client."""
self.info(
"Dataset client not available, requesting metadata from DatasetManager"
)
if self.is_trace_enabled:
self.trace(f"Received response message: {conversation_response}")

# Check for error in conversation response
if isinstance(conversation_response, ErrorMessage):
error = conversation_response.error
await self._send_inference_result_message(
RequestRecord(
request_info=RequestInfo(
model_endpoint=self.model_endpoint,
conversation_id=conversation_id,
turn_index=0,
turns=[],
credit_num=credit_context.credit.id,
credit_phase=credit_context.credit.phase,
x_request_id=str(uuid.uuid4()),
x_correlation_id=credit_context.credit.x_correlation_id,
drop_perf_ns=credit_context.drop_perf_ns,
),
model_name=self.model_endpoint.primary_model_name,
timestamp_ns=time.time_ns(),
start_perf_ns=time.perf_counter_ns(),
end_perf_ns=time.perf_counter_ns(),
error=error,
)
response: (
DatasetClientResponseMessage | ErrorMessage
) = await self._dataset_client_request_client.request(
DatasetClientRequestMessage(service_id=self.service_id)
)

if isinstance(response, ErrorMessage):
raise ValueError(
f"Failed to retrieve dataset client metadata: {response.error}"
)
raise ValueError(f"Failed to retrieve conversation response: {error}")

return conversation_response.conversation
await self._initialize_dataset_client(response.client_metadata)
self.info(
lambda: (
f"Dataset client initialized via fallback request: "
f"type={response.client_metadata.client_type}"
)
)

async def _process_response(self, record: RequestRecord) -> Turn | None:
"""Extract assistant response from RequestRecord and convert to Turn for session.
Expand Down
14 changes: 10 additions & 4 deletions tests/unit/workers/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,12 +299,19 @@ async def test_raises_cancelled_error_when_stop_requested_and_no_client(
async def test_falls_back_to_dataset_manager_when_no_client_and_not_stopping(
self, monkeypatch, mock_worker, sample_credit_context
):
"""When _dataset_client is None and not stopping, should request from DatasetManager."""
"""When _dataset_client is None and not stopping, should request client metadata from DatasetManager."""
mock_worker._dataset_client = None
expected_conversation = Conversation(session_id="test-conv-123", turns=[])
mock_fallback = AsyncMock(return_value=expected_conversation)

async def mock_request_client(self_worker):
mock_client = AsyncMock()
mock_client.get_conversation = AsyncMock(return_value=expected_conversation)
self_worker._dataset_client = mock_client

monkeypatch.setattr(
mock_worker, "_request_conversation_from_dataset_manager", mock_fallback
mock_worker,
"_request_dataset_client_from_dataset_manager",
lambda: mock_request_client(mock_worker),
)

result = await mock_worker._retrieve_conversation(
Expand All @@ -313,4 +320,3 @@ async def test_falls_back_to_dataset_manager_when_no_client_and_not_stopping(
)

assert result == expected_conversation
mock_fallback.assert_called_once_with("test-conv-123", sample_credit_context)