Skip to content

Commit d57a83b

Browse files
authored
[ENHANCEMENT] argilla: link user responses and suggestions to record (#5518)
# Description <!-- Please include a summary of the changes and the related issue. Please also include relevant motivation and context. List any dependencies that are required for this change. --> This PR reviews and exposes the record resources at responses, user responses, and suggestions level. Allowing users to access the corresponding record from responses or suggestions. This is useful when working with webhooks. **Type of change** <!-- Please delete options that are not relevant. Remember to title the PR according to the type of change --> - Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** <!-- Please add some reference about how your feature has been tested. --> **Checklist** <!-- Please go over the list and make sure you've taken everything into account --> - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/)
1 parent 5875d59 commit d57a83b

File tree

4 files changed

+72
-28
lines changed

4 files changed

+72
-28
lines changed

argilla/src/argilla/records/_resource.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -262,19 +262,17 @@ def from_model(cls, model: RecordModel, dataset: "Dataset") -> "Record":
262262
fields=model.fields,
263263
metadata={meta.name: meta.value for meta in model.metadata},
264264
vectors={vector.name: vector.vector_values for vector in model.vectors},
265-
# Responses and their models are not aligned 1-1.
266-
responses=[
267-
response
268-
for response_model in model.responses
269-
for response in UserResponse.from_model(response_model, dataset=dataset)
270-
],
271-
suggestions=[Suggestion.from_model(model=suggestion, dataset=dataset) for suggestion in model.suggestions],
265+
_dataset=dataset,
266+
responses=[],
267+
suggestions=[],
272268
)
273269

274270
# set private attributes
275-
instance._dataset = dataset
276271
instance._model.id = model.id
277272
instance._model.status = model.status
273+
# Responses and suggestions are computed separately based on the record model
274+
instance.responses.from_models(model.responses)
275+
instance.suggestions.from_models(model.suggestions)
278276

279277
return instance
280278

@@ -349,11 +347,10 @@ class RecordResponses(Iterable[Response]):
349347
def __init__(self, responses: List[Response], record: Record) -> None:
350348
self.record = record
351349
self.__responses_by_question_name = defaultdict(list)
350+
self.__responses = []
352351

353-
self.__responses = responses or []
354-
for response in self.__responses:
355-
response.record = self.record
356-
self.__responses_by_question_name[response.question_name].append(response)
352+
for response in responses or []:
353+
self.add(response)
357354

358355
def __iter__(self):
359356
return iter(self.__responses)
@@ -409,6 +406,11 @@ def _check_response_already_exists(self, response: Response) -> None:
409406
f"already found. The responses for the same question name do not support more than one user"
410407
)
411408

409+
def from_models(self, responses: List[UserResponseModel]) -> None:
410+
for response_model in responses:
411+
for response in UserResponse.from_model(response_model, record=self.record):
412+
self.add(response)
413+
412414

413415
class RecordSuggestions(Iterable[Suggestion]):
414416
"""This is a container class for the suggestions of a Record.
@@ -461,3 +463,7 @@ def add(self, suggestion: Suggestion) -> None:
461463
"""
462464
suggestion.record = self.record
463465
self._suggestion_by_question_name[suggestion.question_name] = suggestion
466+
467+
def from_models(self, suggestions: List[SuggestionModel]) -> None:
468+
for suggestion_model in suggestions:
469+
self.add(Suggestion.from_model(suggestion_model, record=self.record))

argilla/src/argilla/responses.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from argilla.settings import RankingQuestion
2424

2525
if TYPE_CHECKING:
26-
from argilla import Argilla, Dataset, Record
26+
from argilla import Argilla, Record
2727

2828
__all__ = ["Response", "UserResponse", "ResponseStatus"]
2929

@@ -71,12 +71,22 @@ def __init__(
7171
if isinstance(status, str):
7272
status = ResponseStatus(status)
7373

74-
self.record = _record
74+
self._record = _record
7575
self.question_name = question_name
7676
self.value = value
7777
self.user_id = user_id
7878
self.status = status
7979

80+
@property
81+
def record(self) -> "Record":
82+
"""Returns the record associated with the response"""
83+
return self._record
84+
85+
@record.setter
86+
def record(self, record: "Record") -> None:
87+
"""Sets the record associated with the response"""
88+
self._record = record
89+
8090
def serialize(self) -> dict[str, Any]:
8191
"""Serializes the Response to a dictionary. This is principally used for sending the response to the API, \
8292
but can be used for data wrangling or manual export.
@@ -138,6 +148,9 @@ def __init__(
138148
user_id=self._compute_user_id_from_responses(responses),
139149
)
140150

151+
for response in responses:
152+
response.record = _record
153+
141154
def __iter__(self) -> Iterable[Response]:
142155
return iter(self.responses)
143156

@@ -164,19 +177,29 @@ def user_id(self, user_id: UUID) -> None:
164177
@property
165178
def responses(self) -> List[Response]:
166179
"""Returns the list of responses"""
167-
return self.__model_as_responses_list(self._model)
180+
return self.__model_as_responses_list(self._model, record=self._record)
181+
182+
@property
183+
def record(self) -> "Record":
184+
"""Returns the record associated with the response"""
185+
return self._record
186+
187+
@record.setter
188+
def record(self, record: "Record") -> None:
189+
"""Sets the record associated with the response"""
190+
self._record = record
168191

169192
@classmethod
170-
def from_model(cls, model: UserResponseModel, dataset: "Dataset") -> "UserResponse":
193+
def from_model(cls, model: UserResponseModel, record: "Record") -> "UserResponse":
171194
"""Creates a UserResponse from a ResponseModel"""
172-
responses = cls.__model_as_responses_list(model)
195+
responses = cls.__model_as_responses_list(model, record=record)
173196
for response in responses:
174-
question = dataset.settings.questions[response.question_name]
197+
question = record.dataset.settings.questions[response.question_name]
175198
# We need to adapt the ranking question value to the expected format
176199
if isinstance(question, RankingQuestion):
177200
response.value = cls.__ranking_from_model_value(response.value) # type: ignore
178201

179-
return cls(responses=responses)
202+
return cls(responses=responses, _record=record)
180203

181204
def api_model(self):
182205
"""Returns the model that is used to interact with the API"""
@@ -223,7 +246,7 @@ def __responses_as_model_values(responses: List[Response]) -> Dict[str, Dict[str
223246
return {answer.question_name: {"value": answer.value} for answer in responses}
224247

225248
@classmethod
226-
def __model_as_responses_list(cls, model: UserResponseModel) -> List[Response]:
249+
def __model_as_responses_list(cls, model: UserResponseModel, record: "Record") -> List[Response]:
227250
"""Creates a list of Responses from a UserResponseModel without changing the format of the values"""
228251

229252
return [
@@ -232,6 +255,7 @@ def __model_as_responses_list(cls, model: UserResponseModel) -> List[Response]:
232255
value=value["value"],
233256
user_id=model.user_id,
234257
status=model.status,
258+
_record=record,
235259
)
236260
for question_name, value in model.values.items()
237261
]

argilla/src/argilla/suggestions.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from argilla.settings import RankingQuestion
2020

2121
if TYPE_CHECKING:
22-
from argilla import Dataset, QuestionType, Record
22+
from argilla import QuestionType, Record
2323

2424
__all__ = ["Suggestion"]
2525

@@ -54,7 +54,7 @@ def __init__(
5454
if value is None:
5555
raise ValueError("value is required")
5656

57-
self.record = _record
57+
self._record = _record
5858
self._model = SuggestionModel(
5959
question_name=question_name,
6060
value=value,
@@ -104,13 +104,22 @@ def agent(self) -> Optional[str]:
104104
def agent(self, value: str) -> None:
105105
self._model.agent = value
106106

107+
@property
108+
def record(self) -> Optional["Record"]:
109+
"""The record that the suggestion is for."""
110+
return self._record
111+
112+
@record.setter
113+
def record(self, value: "Record") -> None:
114+
self._record = value
115+
107116
@classmethod
108-
def from_model(cls, model: SuggestionModel, dataset: "Dataset") -> "Suggestion":
109-
question = dataset.settings.questions[model.question_id]
117+
def from_model(cls, model: SuggestionModel, record: "Record") -> "Suggestion":
118+
question = record.dataset.settings.questions[model.question_id]
110119
model.question_name = question.name
111120
model.value = cls.__from_model_value(model.value, question)
112121

113-
instance = cls(question.name, model.value)
122+
instance = cls(question.name, model.value, _record=record)
114123
instance._model = model
115124

116125
return instance

argilla/tests/unit/test_resources/test_responses.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import pytest
1818

19-
from argilla import UserResponse, Response, Dataset, Workspace
19+
from argilla import UserResponse, Response, Dataset, Workspace, Record
2020
from argilla._models import UserResponseModel, ResponseStatus
2121

2222

@@ -89,9 +89,14 @@ def test_create_user_response_with_multiple_user_id(self):
8989

9090
def test_create_user_response_from_draft_response_model_without_values(self):
9191
model = UserResponseModel(values={}, status=ResponseStatus.draft, user=uuid.uuid4())
92-
response = UserResponse.from_model(
93-
model=model, dataset=Dataset(name="burr", workspace=Workspace(name="test", id=uuid.uuid4()))
92+
93+
record = Record(
94+
fields={"question": "answer"},
95+
_dataset=Dataset(name="burr", workspace=Workspace(name="test", id=uuid.uuid4())),
9496
)
97+
98+
response = UserResponse.from_model(model=model, record=record)
99+
95100
assert len(response.responses) == 0
96101
assert response.user_id is None
97102
assert response.status == ResponseStatus.draft

0 commit comments

Comments
 (0)