Skip to content

Commit a2bbdbe

Browse files
committed
Update types throughout
1 parent f772177 commit a2bbdbe

34 files changed

+193
-189
lines changed

src/a2a/server/models.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ class TaskMixin:
123123
"""Mixin providing standard task columns with proper type handling."""
124124

125125
id: Mapped[str] = mapped_column(String(36), primary_key=True, index=True)
126-
contextId: Mapped[str] = mapped_column(String(36), nullable=False) # noqa: N815
126+
context_id: Mapped[str] = mapped_column(String(36), nullable=False)
127127
kind: Mapped[str] = mapped_column(
128128
String(16), nullable=False, default='task'
129129
)
@@ -148,12 +148,12 @@ def task_metadata(cls) -> Mapped[dict[str, Any] | None]:
148148
def __repr__(self) -> str:
149149
"""Return a string representation of the task."""
150150
repr_template = (
151-
'<{CLS}(id="{ID}", contextId="{CTX_ID}", status="{STATUS}")>'
151+
'<{CLS}(id="{ID}", context_id="{CTX_ID}", status="{STATUS}")>'
152152
)
153153
return repr_template.format(
154154
CLS=self.__class__.__name__,
155155
ID=self.id,
156-
CTX_ID=self.contextId,
156+
CTX_ID=self.context_id,
157157
STATUS=self.status,
158158
)
159159

@@ -188,11 +188,11 @@ class TaskModel(TaskMixin, base):
188188
@override
189189
def __repr__(self) -> str:
190190
"""Return a string representation of the task."""
191-
repr_template = '<TaskModel[{TABLE}](id="{ID}", contextId="{CTX_ID}", status="{STATUS}")>'
191+
repr_template = '<TaskModel[{TABLE}](id="{ID}", context_id="{CTX_ID}", status="{STATUS}")>'
192192
return repr_template.format(
193193
TABLE=table_name,
194194
ID=self.id,
195-
CTX_ID=self.contextId,
195+
CTX_ID=self.context_id,
196196
STATUS=self.status,
197197
)
198198

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ async def on_cancel_task(
126126

127127
task_manager = TaskManager(
128128
task_id=task.id,
129-
context_id=task.contextId,
129+
context_id=task.context_id,
130130
task_store=self.task_store,
131131
initial_message=None,
132132
)
@@ -140,7 +140,7 @@ async def on_cancel_task(
140140
RequestContext(
141141
None,
142142
task_id=task.id,
143-
context_id=task.contextId,
143+
context_id=task.context_id,
144144
task=task,
145145
),
146146
queue,
@@ -185,7 +185,7 @@ async def _setup_message_execution(
185185
# Create task manager and validate existing task
186186
task_manager = TaskManager(
187187
task_id=params.message.taskId,
188-
context_id=params.message.contextId,
188+
context_id=params.message.context_id,
189189
task_store=self.task_store,
190190
initial_message=params.message,
191191
)
@@ -205,7 +205,7 @@ async def _setup_message_execution(
205205
request_context = await self._request_context_builder.build(
206206
params=params,
207207
task_id=task.id if task else None,
208-
context_id=params.message.contextId,
208+
context_id=params.message.context_id,
209209
task=task,
210210
context=context,
211211
)
@@ -430,7 +430,7 @@ async def on_resubscribe_to_task(
430430

431431
task_manager = TaskManager(
432432
task_id=task.id,
433-
context_id=task.contextId,
433+
context_id=task.context_id,
434434
task_store=self.task_store,
435435
initial_message=None,
436436
)

src/a2a/server/tasks/database_task_store.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def _to_orm(self, task: Task) -> TaskModel:
9595
"""Maps a Pydantic Task to a SQLAlchemy TaskModel instance."""
9696
return self.task_model(
9797
id=task.id,
98-
contextId=task.contextId,
98+
context_id=task.context_id,
9999
kind=task.kind,
100100
status=task.status,
101101
artifacts=task.artifacts,
@@ -108,7 +108,7 @@ def _from_orm(self, task_model: TaskModel) -> Task:
108108
# Map database columns to Pydantic model fields
109109
task_data_from_db = {
110110
'id': task_model.id,
111-
'contextId': task_model.contextId,
111+
'context_id': task_model.context_id,
112112
'kind': task_model.kind,
113113
'status': task_model.status,
114114
'artifacts': task_model.artifacts,

src/a2a/server/tasks/task_manager.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -110,14 +110,14 @@ async def save_task_event(
110110
)
111111
if not self.task_id:
112112
self.task_id = task_id_from_event
113-
if self.context_id and self.context_id != event.contextId:
113+
if self.context_id and self.context_id != event.context_id:
114114
raise ServerError(
115115
error=InvalidParamsError(
116-
message=f"Context in event doesn't match TaskManager {self.context_id} : {event.contextId}"
116+
message=f"Context in event doesn't match TaskManager {self.context_id} : {event.context_id}"
117117
)
118118
)
119119
if not self.context_id:
120-
self.context_id = event.contextId
120+
self.context_id = event.context_id
121121

122122
logger.debug(
123123
'Processing save of task event of type %s for task_id: %s',
@@ -173,11 +173,11 @@ async def ensure_task(
173173
logger.info(
174174
'Task not found or task_id not set. Creating new task for event (task_id: %s, context_id: %s).',
175175
event.taskId,
176-
event.contextId,
176+
event.context_id,
177177
)
178178
# streaming agent did not previously stream task object.
179179
# Create a task object with the available information and persist the event
180-
task = self._init_task_obj(event.taskId, event.contextId)
180+
task = self._init_task_obj(event.taskId, event.context_id)
181181
await self._save_task(task)
182182

183183
return task
@@ -219,7 +219,7 @@ def _init_task_obj(self, task_id: str, context_id: str) -> Task:
219219
history = [self._initial_message] if self._initial_message else []
220220
return Task(
221221
id=task_id,
222-
contextId=context_id,
222+
context_id=context_id,
223223
status=TaskStatus(state=TaskState.submitted),
224224
history=history,
225225
)
@@ -236,7 +236,7 @@ async def _save_task(self, task: Task) -> None:
236236
if not self.task_id:
237237
logger.info('New task created with id: %s', task.id)
238238
self.task_id = task.id
239-
self.context_id = task.contextId
239+
self.context_id = task.context_id
240240

241241
def update_with_message(self, message: Message, task: Task) -> Task:
242242
"""Updates a task object in memory by adding a new message to its history.

src/a2a/server/tasks/task_updater.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ async def update_status(
7575
await self.event_queue.enqueue_event(
7676
TaskStatusUpdateEvent(
7777
taskId=self.task_id,
78-
contextId=self.context_id,
78+
context_id=self.context_id,
7979
final=final,
8080
status=TaskStatus(
8181
state=state,
@@ -110,7 +110,7 @@ async def add_artifact( # noqa: PLR0913
110110
await self.event_queue.enqueue_event(
111111
TaskArtifactUpdateEvent(
112112
taskId=self.task_id,
113-
contextId=self.context_id,
113+
context_id=self.context_id,
114114
artifact=Artifact(
115115
artifactId=artifact_id,
116116
name=name,
@@ -198,7 +198,7 @@ def new_agent_message(
198198
return Message(
199199
role=Role.agent,
200200
taskId=self.task_id,
201-
contextId=self.context_id,
201+
context_id=self.context_id,
202202
messageId=str(uuid.uuid4()),
203203
metadata=metadata,
204204
parts=parts,

src/a2a/utils/helpers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@ def create_task_obj(message_send_params: MessageSendParams) -> Task:
3636
Returns:
3737
A new `Task` object initialized with 'submitted' status and the input message in history.
3838
"""
39-
if not message_send_params.message.contextId:
40-
message_send_params.message.contextId = str(uuid4())
39+
if not message_send_params.message.context_id:
40+
message_send_params.message.context_id = str(uuid4())
4141

4242
return Task(
4343
id=str(uuid4()),
44-
contextId=message_send_params.message.contextId,
44+
context_id=message_send_params.message.context_id,
4545
status=TaskStatus(state=TaskState.submitted),
4646
history=[message_send_params.message],
4747
)

src/a2a/utils/message.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def new_agent_text_message(
3636
parts=[Part(root=TextPart(text=text))],
3737
messageId=str(uuid.uuid4()),
3838
taskId=task_id,
39-
contextId=context_id,
39+
context_id=context_id,
4040
)
4141

4242

@@ -60,7 +60,7 @@ def new_agent_parts_message(
6060
parts=parts,
6161
messageId=str(uuid.uuid4()),
6262
taskId=task_id,
63-
contextId=context_id,
63+
context_id=context_id,
6464
)
6565

6666

src/a2a/utils/proto_utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def message(cls, message: types.Message | None) -> a2a_pb2.Message | None:
2828
return a2a_pb2.Message(
2929
message_id=message.messageId,
3030
content=[ToProto.part(p) for p in message.parts],
31-
context_id=message.contextId,
31+
context_id=message.context_id,
3232
task_id=message.taskId,
3333
role=cls.role(message.role),
3434
metadata=ToProto.metadata(message.metadata),
@@ -81,7 +81,7 @@ def file(
8181
def task(cls, task: types.Task) -> a2a_pb2.Task:
8282
return a2a_pb2.Task(
8383
id=task.id,
84-
context_id=task.contextId,
84+
context_id=task.context_id,
8585
status=ToProto.task_status(task.status),
8686
artifacts=(
8787
[ToProto.artifact(a) for a in task.artifacts]
@@ -161,7 +161,7 @@ def task_artifact_update_event(
161161
) -> a2a_pb2.TaskArtifactUpdateEvent:
162162
return a2a_pb2.TaskArtifactUpdateEvent(
163163
task_id=event.taskId,
164-
context_id=event.contextId,
164+
context_id=event.context_id,
165165
artifact=ToProto.artifact(event.artifact),
166166
metadata=ToProto.metadata(event.metadata),
167167
append=event.append or False,
@@ -174,7 +174,7 @@ def task_status_update_event(
174174
) -> a2a_pb2.TaskStatusUpdateEvent:
175175
return a2a_pb2.TaskStatusUpdateEvent(
176176
task_id=event.taskId,
177-
context_id=event.contextId,
177+
context_id=event.context_id,
178178
status=ToProto.task_status(event.status),
179179
metadata=ToProto.metadata(event.metadata),
180180
final=event.final,
@@ -438,7 +438,7 @@ def message(cls, message: a2a_pb2.Message) -> types.Message:
438438
return types.Message(
439439
messageId=message.message_id,
440440
parts=[FromProto.part(p) for p in message.content],
441-
contextId=message.context_id,
441+
context_id=message.context_id,
442442
taskId=message.task_id,
443443
role=FromProto.role(message.role),
444444
metadata=FromProto.metadata(message.metadata),
@@ -483,7 +483,7 @@ def file(
483483
def task(cls, task: a2a_pb2.Task) -> types.Task:
484484
return types.Task(
485485
id=task.id,
486-
contextId=task.context_id,
486+
context_id=task.context_id,
487487
status=FromProto.task_status(task.status),
488488
artifacts=[FromProto.artifact(a) for a in task.artifacts],
489489
history=[FromProto.message(h) for h in task.history],
@@ -530,7 +530,7 @@ def task_artifact_update_event(
530530
) -> types.TaskArtifactUpdateEvent:
531531
return types.TaskArtifactUpdateEvent(
532532
taskId=event.task_id,
533-
contextId=event.context_id,
533+
context_id=event.context_id,
534534
artifact=FromProto.artifact(event.artifact),
535535
metadata=FromProto.metadata(event.metadata),
536536
append=event.append,
@@ -543,7 +543,7 @@ def task_status_update_event(
543543
) -> types.TaskStatusUpdateEvent:
544544
return types.TaskStatusUpdateEvent(
545545
taskId=event.task_id,
546-
contextId=event.context_id,
546+
context_id=event.context_id,
547547
status=FromProto.task_status(event.status),
548548
metadata=FromProto.metadata(event.metadata),
549549
final=event.final,

src/a2a/utils/task.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ def new_task(request: Message) -> Task:
2828
return Task(
2929
status=TaskStatus(state=TaskState.submitted),
3030
id=(request.taskId if request.taskId else str(uuid.uuid4())),
31-
contextId=(
32-
request.contextId if request.contextId else str(uuid.uuid4())
31+
context_id=(
32+
request.context_id if request.context_id else str(uuid.uuid4())
3333
),
3434
history=[request],
3535
)
@@ -65,7 +65,7 @@ def completed_task(
6565
return Task(
6666
status=TaskStatus(state=TaskState.completed),
6767
id=task_id,
68-
contextId=context_id,
68+
context_id=context_id,
6969
artifacts=artifacts,
7070
history=history,
7171
)

tests/client/test_auth_middleware.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def build_success_response() -> dict:
5555
jsonrpc='2.0',
5656
result=Message(
5757
kind='message',
58-
messageId='message-id',
58+
message_id='message-id',
5959
role=Role.agent,
6060
parts=[],
6161
),
@@ -68,7 +68,7 @@ def build_send_message_request() -> SendMessageRequest:
6868
id='1',
6969
params=MessageSendParams(
7070
message=Message(
71-
messageId='msg1',
71+
message_id='msg1',
7272
role=Role.user,
7373
parts=[],
7474
)
@@ -223,9 +223,9 @@ class AuthTestCase:
223223
security_scheme=OAuth2SecurityScheme(
224224
type='oauth2',
225225
flows=OAuthFlows(
226-
authorizationCode=AuthorizationCodeOAuthFlow(
227-
authorizationUrl='http://provider.com/auth',
228-
tokenUrl='http://provider.com/token',
226+
authorization_code=AuthorizationCodeOAuthFlow(
227+
authorization_url='http://provider.com/auth',
228+
token_url='http://provider.com/token',
229229
scopes={'read': 'Read scope'},
230230
)
231231
),
@@ -242,7 +242,7 @@ class AuthTestCase:
242242
credential='secret-oidc-id-token',
243243
security_scheme=OpenIdConnectSecurityScheme(
244244
type='openIdConnect',
245-
openIdConnectUrl='http://provider.com/.well-known/openid-configuration',
245+
open_id_connect_url='http://provider.com/.well-known/openid-configuration',
246246
),
247247
expected_header_key='Authorization',
248248
expected_header_value_func=lambda c: f'Bearer {c}',
@@ -282,12 +282,12 @@ async def test_auth_interceptor_variants(test_case, store):
282282
name=f'{test_case.scheme_name}bot',
283283
description=f'A bot that uses {test_case.scheme_name}',
284284
version='1.0',
285-
defaultInputModes=[],
286-
defaultOutputModes=[],
285+
default_input_modes=[],
286+
default_output_modes=[],
287287
skills=[],
288288
capabilities=AgentCapabilities(),
289289
security=[{test_case.scheme_name: []}],
290-
securitySchemes={
290+
security_schemes={
291291
test_case.scheme_name: SecurityScheme(
292292
root=test_case.security_scheme
293293
)
@@ -314,7 +314,7 @@ async def test_auth_interceptor_skips_when_scheme_not_in_security_schemes(
314314
):
315315
"""
316316
Tests that AuthInterceptor skips a scheme if it's listed in security requirements
317-
but not defined in securitySchemes.
317+
but not defined in security_schemes.
318318
"""
319319
scheme_name = 'missing'
320320
session_id = 'session-id'
@@ -328,12 +328,12 @@ async def test_auth_interceptor_skips_when_scheme_not_in_security_schemes(
328328
name='missingbot',
329329
description='A bot that uses missing scheme definition',
330330
version='1.0',
331-
defaultInputModes=[],
332-
defaultOutputModes=[],
331+
default_input_modes=[],
332+
default_output_modes=[],
333333
skills=[],
334334
capabilities=AgentCapabilities(),
335335
security=[{scheme_name: []}],
336-
securitySchemes={},
336+
security_schemes={},
337337
)
338338

339339
new_payload, new_kwargs = await auth_interceptor.intercept(

0 commit comments

Comments
 (0)